summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Rettberg2021-03-31 12:31:13 +0200
committerSimon Rettberg2021-04-14 13:17:59 +0200
commit79e23317fff58c2c7eaace6df781bccc08775f28 (patch)
tree589c4b708484768b0524bf11ac6089bf49de06be
parent[KERNEL] Fix CMD name in debug messages (diff)
downloaddnbd3-79e23317fff58c2c7eaace6df781bccc08775f28.tar.gz
dnbd3-79e23317fff58c2c7eaace6df781bccc08775f28.tar.xz
dnbd3-79e23317fff58c2c7eaace6df781bccc08775f28.zip
[KERNEL] Deduplicate code, clean up, split into functions
-rw-r--r--src/kernel/net.c732
-rw-r--r--src/kernel/net.h9
2 files changed, 339 insertions, 402 deletions
diff --git a/src/kernel/net.c b/src/kernel/net.c
index 3c19f8d..b07e8dc 100644
--- a/src/kernel/net.c
+++ b/src/kernel/net.c
@@ -39,19 +39,28 @@
#endif
#ifdef CONFIG_DEBUG_DRIVER
-#define ASSERT(x) \
- do { \
- if (!(x)) { \
- printk(KERN_EMERG "assertion failed %s: %d: %s\n", __FILE__, __LINE__, #x); \
- BUG(); \
- } \
+#define ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ printk(KERN_EMERG "assertion failed %s: %d: %s\n", __FILE__, __LINE__, #x); \
+ BUG(); \
+ } \
} while (0)
#else
-#define ASSERT(x) \
- do { \
+#define ASSERT(x) \
+ do { \
} while (0)
#endif
+#define init_msghdr(h) \
+ do { \
+ h.msg_name = NULL; \
+ h.msg_namelen = 0; \
+ h.msg_control = NULL; \
+ h.msg_controllen = 0; \
+ h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \
+ } while (0)
+
// cmd_flags and cmd_type are merged into cmd_flags now
#if REQ_FLAG_BITS > 24
#error "Fix CMD bitshift"
@@ -63,18 +72,23 @@
#define DNBD3_DEV_READ REQ_OP_READ
#define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN
+#define dnbd3_dev_dbg_host(dev, host, fmt, ...) \
+ dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__)
+#define dnbd3_dev_err_host(dev, host, fmt, ...) \
+ dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__)
+
#define dnbd3_dev_dbg_host_cur(dev, fmt, ...) \
- dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->cur_server.host, ##__VA_ARGS__)
+ dnbd3_dev_dbg_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__)
#define dnbd3_dev_err_host_cur(dev, fmt, ...) \
- dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->cur_server.host, ##__VA_ARGS__)
-
-#define dnbd3_dev_dbg_host_alt(dev, fmt, ...) \
- dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->alt_servers[i].host, ##__VA_ARGS__)
-#define dnbd3_dev_err_host_alt(dev, fmt, ...) \
- dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, &(dev)->alt_servers[i].host, ##__VA_ARGS__)
+ dnbd3_dev_err_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__)
static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr);
+static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock,
+ struct sockaddr_storage *addr, uint16_t *remote_version);
+
+int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock);
+
static void dnbd3_net_heartbeat(struct timer_list *arg)
{
dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer);
@@ -122,23 +136,10 @@ static void dnbd3_net_heartbeat(struct timer_list *arg)
static int dnbd3_net_discover(void *data)
{
dnbd3_device_t *dev = data;
- struct sockaddr_in sin4;
- struct sockaddr_in6 sin6;
struct socket *sock, *best_sock = NULL;
-
- dnbd3_request_t dnbd3_request;
- dnbd3_reply_t dnbd3_reply;
dnbd3_alt_server_t *alt;
struct sockaddr_storage host_compare, best_server;
- struct msghdr msg;
- struct kvec iov[2];
-
- char *buf, *name;
- serialized_buffer_t *payload;
- uint64_t filesize;
- uint16_t rid;
uint16_t remote_version;
-
ktime_t start = 0, end = 0;
unsigned long rtt, best_rtt = 0;
unsigned long irqflags;
@@ -146,25 +147,9 @@ static int dnbd3_net_discover(void *data)
int turn = 0;
int ready = 0, do_change = 0;
char check_order[NUMBER_SERVERS];
- int mlen;
struct request *last_request = (struct request *)123, *cur_request = (struct request *)456;
- memset(&sin4, 0, sizeof(sin4));
- memset(&sin6, 0, sizeof(sin6));
-
- init_msghdr(msg);
-
- BUILD_BUG_ON(sizeof(serialized_buffer_t) > DNBD3_BLOCK_SIZE);
-
- buf = kmalloc(DNBD3_BLOCK_SIZE, GFP_KERNEL);
- if (!buf)
- return -ENOMEM;
-
- payload = (serialized_buffer_t *)buf; // Reuse this buffer to save kernel mem
-
- dnbd3_request.magic = dnbd3_packet_magic;
-
for (i = 0; i < NUMBER_SERVERS; ++i)
check_order[i] = i;
@@ -194,9 +179,9 @@ static int dnbd3_net_discover(void *data)
for (i = 0; i < isize; ++i) {
j = ((ktime_to_s(start) >> i) ^ (ktime_to_us(start) >> j)) % NUMBER_SERVERS;
if (j != i) {
- mlen = check_order[i];
+ int tmp = check_order[i];
check_order[i] = check_order[j];
- check_order[j] = mlen;
+ check_order[j] = tmp;
}
}
}
@@ -220,93 +205,13 @@ static int dnbd3_net_discover(void *data)
if (sock == NULL)
goto error;
- // Request filesize
- dnbd3_request.cmd = CMD_SELECT_IMAGE;
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
- serializer_reset_write(payload);
- serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version
- serializer_put_string(payload, dev->imgname); // image name
- serializer_put_uint16(payload, dev->rid); // revision id
- serializer_put_uint8(payload, 0); // are we a server? (no!)
- iov[1].iov_base = payload;
- dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload);
- fixup_request(dnbd3_request);
- mlen = iov[1].iov_len + sizeof(dnbd3_request);
- if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) {
- dnbd3_dev_err_host_alt(dev, "requesting image size failed\n");
+ if (!dnbd3_execute_handshake(dev, sock, &host_compare, &remote_version))
goto error;
- }
- // receive net reply
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) !=
- sizeof(dnbd3_reply)) {
- dnbd3_dev_err_host_alt(dev, "receiving image size packet (header) failed (discover)\n");
- goto error;
- }
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE ||
- dnbd3_reply.size < 4) {
- dnbd3_dev_err_host_alt(dev,
- "content of image size packet (header) mismatched (discover)\n");
- goto error;
- }
-
- // receive data
- iov[0].iov_base = payload;
- iov[0].iov_len = dnbd3_reply.size;
- if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) {
- dnbd3_dev_err_host_alt(dev,
- "receiving image size packet (payload) failed (discover)\n");
- goto error;
- }
- serializer_reset_read(payload, dnbd3_reply.size);
-
- remote_version = serializer_get_uint16(payload);
- if (remote_version < MIN_SUPPORTED_SERVER) {
- dnbd3_dev_err_host_alt(
- dev, "server version too old (client: %d, server: %d, min supported: %d)\n",
- (int)PROTOCOL_VERSION, (int)remote_version,
- (int)MIN_SUPPORTED_SERVER);
- goto error;
- }
-
- name = serializer_get_string(payload);
- if (name == NULL) {
- dnbd3_dev_err_host_alt(dev, "server did not supply an image name (discover)\n");
- goto error;
- }
-
- if (strcmp(name, dev->imgname) != 0) {
- dnbd3_dev_err_host_alt(
- dev,
- "image name does not match requested one (client: '%s', server: '%s') (discover)\n",
- dev->imgname, name);
- goto error;
- }
-
- rid = serializer_get_uint16(payload);
- if (rid != dev->rid) {
- dnbd3_dev_err_host_alt(
- dev, "server supplied wrong rid (client: '%d', server: '%d') (discover)\n",
- (int)dev->rid, (int)rid);
- goto error;
- }
-
- filesize = serializer_get_uint64(payload);
- if (filesize != dev->reported_size) {
- dnbd3_dev_err_host_alt(
- dev,
- "reported image size of %llu does not match expected value %llu (discover)\n",
- (unsigned long long)filesize, (unsigned long long)dev->reported_size);
- goto error;
- }
// panic mode, take first responding server
if (dev->panic) {
- dnbd3_dev_dbg_host_alt(dev, "panic mode, changing server ...\n");
+ dnbd3_dev_dbg_host(dev, &host_compare, "panic mode, changing to new server\n");
while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
schedule();
@@ -317,7 +222,6 @@ static int dnbd3_net_discover(void *data)
sock_release(best_sock);
dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect();
- kfree(buf);
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
dnbd3_net_disconnect(dev);
@@ -331,48 +235,11 @@ static int dnbd3_net_discover(void *data)
atomic_set(&dev->connection_lock, 0);
}
- // Request block
- dnbd3_request.cmd = CMD_GET_BLOCK;
- // Do *NOT* pick a random block as it has proven to cause severe
- // cache thrashing on the server
- dnbd3_request.offset = 0;
- dnbd3_request.size = RTT_BLOCK_SIZE;
- fixup_request(dnbd3_request);
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
-
// start rtt measurement
start = ktime_get_real();
- if (kernel_sendmsg(sock, &msg, iov, 1, sizeof(dnbd3_request)) <= 0) {
- dnbd3_dev_err_host_alt(dev, "requesting test block failed (discover)\n");
- goto error;
- }
-
- // receive net reply
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) !=
- sizeof(dnbd3_reply)) {
- dnbd3_dev_err_host_alt(dev, "receiving test block header packet failed (discover)\n");
- goto error;
- }
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_GET_BLOCK ||
- dnbd3_reply.size != RTT_BLOCK_SIZE) {
- dnbd3_dev_err_host_alt(
- dev, "unexpected reply to block request: cmd=%d, size=%d (discover)\n",
- (int)dnbd3_reply.cmd, (int)dnbd3_reply.size);
- goto error;
- }
-
- // receive data
- iov[0].iov_base = buf;
- iov[0].iov_len = RTT_BLOCK_SIZE;
- if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != RTT_BLOCK_SIZE) {
- dnbd3_dev_err_host_alt(dev, "receiving test block payload failed (discover)\n");
+ if (!dnbd3_request_test_block(dev, &host_compare, sock))
goto error;
- }
end = ktime_get_real(); // end rtt measurement
@@ -481,7 +348,6 @@ error:
dev_info(dnbd3_device_to_dev(dev), "server %pISpc is faster (%lluµs vs. %lluµs)\n",
&best_server,
(unsigned long long)best_rtt, (unsigned long long)dev->cur_server.rtt);
- kfree(buf);
dev->better_sock = best_sock; // Take shortcut by continuing to use open connection
put_task_struct(dev->thread_discover);
dev->thread_discover = NULL;
@@ -508,7 +374,6 @@ error:
ready = 1;
}
- kfree(buf);
if (kthread_should_stop())
dev_dbg(dnbd3_device_to_dev(dev), "kthread %s terminated normally\n", __func__);
else
@@ -880,6 +745,7 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage
ktime_t start;
int ret, connect_time_ms;
struct socket *sock;
+ int retries = 4;
if (sock_create_kern(&init_net, addr->ss_family, SOCK_STREAM, IPPROTO_TCP, &sock) < 0) {
dev_err(dnbd3_device_to_dev(dev), "couldn't create socket\n");
@@ -914,40 +780,296 @@ static struct socket *dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage
connect_time_ms = SOCKET_TIMEOUT_CLIENT_DATA * 1000;
set_socket_timeouts(sock, connect_time_ms);
start = ktime_get_real();
- ret = kernel_connect(sock, (struct sockaddr *)addr, sizeof(*addr), 0);
- connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start);
- if (connect_time_ms > 2 * SOCKET_TIMEOUT_CLIENT_DATA * 1000) {
- /* Either I'm losing my mind or there was a specific build of kernel
- * 5.x where SO_RCVTIMEO didn't affect the connect call above, so
- * this function would hang for over a minute for unreachable hosts.
- * Leave in this debug check for twice the configured timeout
- */
- dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect call took %dms\n",
- addr, connect_time_ms);
- }
- if (ret != 0) {
- dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect failed (%d, blocked %dms)\n",
- addr, ret, connect_time_ms);
- goto error;
+ while (--retries > 0) {
+ ret = kernel_connect(sock, (struct sockaddr *)addr, sizeof(*addr), 0);
+ connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start);
+ if (connect_time_ms > 2 * SOCKET_TIMEOUT_CLIENT_DATA * 1000) {
+ /* Either I'm losing my mind or there was a specific build of kernel
+ * 5.x where SO_RCVTIMEO didn't affect the connect call above, so
+ * this function would hang for over a minute for unreachable hosts.
+ * Leave in this debug check for twice the configured timeout
+ */
+ dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect call took %dms\n",
+ addr, connect_time_ms);
+ }
+ if (ret != 0) {
+ if (ret == -EINTR)
+ continue;
+ dev_dbg(dnbd3_device_to_dev(dev), "%pISpc connect failed (%d, blocked %dms)\n",
+ addr, ret, connect_time_ms);
+ goto error;
+ }
+ return sock;
}
- return sock;
error:
sock_release(sock);
return NULL;
}
+/**
+ * Execute protocol handshake on a newly connected socket.
+ * If this is the initial connection to any server, ie. we're being called
+ * through the initial ioctl() to open a device, we'll store the rid, filesize
+ * etc. in the dev struct., otherwise, this is a potential switch to another
+ * server, so we validate the filesize, rid, name against what we expect.
+ * The server's protocol version is returned in 'remote_version'
+ */
+static int dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock,
+ struct sockaddr_storage *addr, uint16_t *remote_version)
+{
+ const char *name;
+ uint64_t filesize;
+ int mlen;
+ uint16_t rid, initial_connect;
+ struct msghdr msg;
+ struct kvec iov[2];
+ serialized_buffer_t *payload;
+ dnbd3_reply_t dnbd3_reply;
+ dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic };
+
+ payload = kmalloc(sizeof(*payload), GFP_KERNEL);
+ if (payload == NULL)
+ goto error;
+
+ initial_connect = (dev->reported_size == 0);
+ init_msghdr(msg);
+ // Request filesize
+ dnbd3_request.cmd = CMD_SELECT_IMAGE;
+ iov[0].iov_base = &dnbd3_request;
+ iov[0].iov_len = sizeof(dnbd3_request);
+ serializer_reset_write(payload);
+ serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version
+ serializer_put_string(payload, dev->imgname); // image name
+ serializer_put_uint16(payload, dev->rid); // revision id
+ serializer_put_uint8(payload, 0); // are we a server? (no!)
+ iov[1].iov_base = payload;
+ dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload);
+ fixup_request(dnbd3_request);
+ mlen = iov[0].iov_len + iov[1].iov_len;
+ if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) {
+ dnbd3_dev_err_host(dev, addr, "requesting image size failed\n");
+ goto error;
+ }
+
+ // receive net reply
+ iov[0].iov_base = &dnbd3_reply;
+ iov[0].iov_len = sizeof(dnbd3_reply);
+ if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply)) {
+ dnbd3_dev_err_host(dev, addr, "receiving image size packet (header) failed\n");
+ goto error;
+ }
+ fixup_reply(dnbd3_reply);
+ if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 4) {
+ dnbd3_dev_err_host(dev, addr,
+ "corrupted CMD_SELECT_IMAGE reply\n");
+ goto error;
+ }
+
+ // receive data
+ iov[0].iov_base = payload;
+ iov[0].iov_len = dnbd3_reply.size;
+ if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) {
+ dnbd3_dev_err_host(dev, addr,
+ "receiving payload of CMD_SELECT_IMAGE reply failed\n");
+ goto error;
+ }
+ serializer_reset_read(payload, dnbd3_reply.size);
+
+ *remote_version = serializer_get_uint16(payload);
+ name = serializer_get_string(payload);
+ rid = serializer_get_uint16(payload);
+ filesize = serializer_get_uint64(payload);
+
+ if (*remote_version < MIN_SUPPORTED_SERVER) {
+ dnbd3_dev_err_host(dev, addr,
+ "server version too old (client: %d, server: %d, min supported: %d)\n",
+ (int)PROTOCOL_VERSION, (int)*remote_version,
+ (int)MIN_SUPPORTED_SERVER);
+ goto error;
+ }
+
+ if (name == NULL) {
+ dnbd3_dev_err_host(dev, addr, "server did not supply an image name\n");
+ goto error;
+ }
+ if (rid == 0) {
+ dnbd3_dev_err_host(dev, addr, "server did not supply a revision id\n");
+ goto error;
+ }
+
+ /* only check image name if this isn't the initial connect */
+ if (initial_connect && dev->rid != 0 && strcmp(name, dev->imgname) != 0) {
+ dnbd3_dev_err_host(dev, addr, "server offers image '%s', requested '%s'\n", name, dev->imgname);
+ goto error;
+ }
+
+ if (initial_connect) {
+ if (filesize < DNBD3_BLOCK_SIZE) {
+ dnbd3_dev_err_host(dev, addr, "reported size by server is < 4096\n");
+ goto error;
+ }
+ if (strlen(dev->imgname) < strlen(name)) {
+ dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_KERNEL);
+ if (dev->imgname == NULL) {
+ dnbd3_dev_err_host(dev, addr, "reallocating buffer for new image name failed\n");
+ goto error;
+ }
+ }
+ strcpy(dev->imgname, name);
+ dev->rid = rid;
+ // store image information
+ dev->reported_size = filesize;
+ set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */
+ dnbd3_dev_dbg_host(dev, addr, "image size: %llu\n", dev->reported_size);
+ dev->update_available = 0;
+ } else {
+ /* switching connection, sanity checks */
+ if (rid != dev->rid) {
+ dnbd3_dev_err_host(dev, addr,
+ "server supplied wrong rid (client: '%d', server: '%d')\n",
+ (int)dev->rid, (int)rid);
+ goto error;
+ }
+
+ if (filesize != dev->reported_size) {
+ dnbd3_dev_err_host(dev, addr,
+ "reported image size of %llu does not match expected value %llu\n",
+ (unsigned long long)filesize, (unsigned long long)dev->reported_size);
+ goto error;
+ }
+ }
+ kfree(payload);
+ return 1;
+
+error:
+ kfree(payload);
+ return 0;
+}
+
+int dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock)
+{
+ dnbd3_request_t dnbd3_request = { .magic = dnbd3_packet_magic };
+ dnbd3_reply_t dnbd3_reply;
+ struct kvec iov;
+ struct msghdr msg;
+ char *buf = NULL;
+ char smallbuf[256];
+ int remaining, buffer_size, ret, func_return;
+
+ init_msghdr(msg);
+
+ func_return = 0;
+ // Request block
+ dnbd3_request.cmd = CMD_GET_BLOCK;
+ // Do *NOT* pick a random block as it has proven to cause severe
+ // cache thrashing on the server
+ dnbd3_request.offset = 0;
+ dnbd3_request.size = RTT_BLOCK_SIZE;
+ fixup_request(dnbd3_request);
+ iov.iov_base = &dnbd3_request;
+ iov.iov_len = sizeof(dnbd3_request);
+
+ if (kernel_sendmsg(sock, &msg, &iov, 1, sizeof(dnbd3_request)) <= 0) {
+ dnbd3_dev_err_host(dev, addr, "requesting test block failed\n");
+ goto error;
+ }
+
+ // receive net reply
+ iov.iov_base = &dnbd3_reply;
+ iov.iov_len = sizeof(dnbd3_reply);
+ if (kernel_recvmsg(sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags)
+ != sizeof(dnbd3_reply)) {
+ dnbd3_dev_err_host(dev, addr, "receiving test block header packet failed\n");
+ goto error;
+ }
+ fixup_reply(dnbd3_reply);
+ if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_GET_BLOCK
+ || dnbd3_reply.size != RTT_BLOCK_SIZE) {
+ dnbd3_dev_err_host(dev, addr,
+ "unexpected reply to block request: cmd=%d, size=%d (discover)\n",
+ (int)dnbd3_reply.cmd, (int)dnbd3_reply.size);
+ goto error;
+ }
+
+ // receive data
+ buf = kmalloc(DNBD3_BLOCK_SIZE, GFP_NOWAIT);
+ if (buf == NULL) {
+ /* fallback to stack if we're really memory constrained */
+ buf = smallbuf;
+ buffer_size = sizeof(smallbuf);
+ } else {
+ buffer_size = DNBD3_BLOCK_SIZE;
+ }
+ remaining = RTT_BLOCK_SIZE;
+ /* TODO in either case we could build a large iovec that points to the same buffer over and over again */
+ while (remaining > 0) {
+ iov.iov_base = buf;
+ iov.iov_len = buffer_size;
+ ret = kernel_recvmsg(sock, &msg, &iov, 1, MIN(remaining, buffer_size), msg.msg_flags);
+ if (ret <= 0) {
+ dnbd3_dev_err_host(dev, addr, "receiving test block payload failed (ret=%d)\n", ret);
+ goto error;
+ }
+ remaining -= ret;
+ }
+ func_return = 1;
+ // Fallthrough!
+error:
+ if (buf != smallbuf)
+ kfree(buf);
+ return func_return;
+}
+
+static int spawn_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name,
+ int (*threadfn)(void *data))
+{
+ ASSERT(*task == NULL);
+ *task = kthread_create(threadfn, dev, "%s-%s", dev->disk->disk_name, name);
+ if (!IS_ERR(*task)) {
+ get_task_struct(*task);
+ wake_up_process(*task);
+ return 1;
+ }
+ dev_err(dnbd3_device_to_dev(dev), "failed to create %s thread (%ld)\n",
+ name, PTR_ERR(*task));
+ /* reset possible non-NULL error value */
+ *task = NULL;
+ return 0;
+}
+
+static void stop_worker_thread(dnbd3_device_t *dev, struct task_struct **task, const char *name, int quiet)
+{
+ int ret;
+
+ if (*task == NULL)
+ return;
+ if (!quiet)
+ dnbd3_dev_dbg_host_cur(dev, "stop %s thread\n", name);
+ ret = kthread_stop(*task);
+ put_task_struct(*task);
+ if (ret == -EINTR) {
+ /* thread has never been scheduled and run */
+ if (!quiet)
+ dev_dbg(dnbd3_device_to_dev(dev), "%s thread has never run\n", name);
+ } else {
+ /* thread has run, check if it has terminated successfully */
+ if (ret < 0 && !quiet)
+ dev_err(dnbd3_device_to_dev(dev), "%s thread was not terminated correctly\n", name);
+ }
+ *task = NULL;
+}
+
int dnbd3_net_connect(dnbd3_device_t *dev)
{
- struct request *req1 = NULL;
+ struct request *req_alt_servers = NULL;
unsigned long irqflags;
ASSERT(atomic_read(&dev->connection_lock));
- // do some checks before connecting
- req1 = kmalloc(sizeof(*req1), GFP_ATOMIC);
- if (!req1) {
- dnbd3_dev_err_host_cur(dev, "kmalloc failed\n");
- goto error;
+ if (dev->use_server_provided_alts) {
+ req_alt_servers = kmalloc(sizeof(*req_alt_servers), GFP_KERNEL);
+ if (req_alt_servers == NULL)
+ dnbd3_dev_err_host_cur(dev, "Cannot allocate memory to request list of alt servers\n");
}
if (dev->cur_server.host.ss_family == 0 || dev->imgname == NULL) {
@@ -964,165 +1086,50 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
ASSERT(dev->thread_receive == NULL);
ASSERT(dev->thread_discover == NULL);
- dnbd3_dev_dbg_host_cur(dev, "connecting ...\n");
-
if (dev->better_sock != NULL) {
// Switching server, connection is already established and size request was executed
- dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change ...\n");
+ dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change\n");
dev->sock = dev->better_sock;
dev->better_sock = NULL;
} else {
// no established connection yet from discovery thread, start new one
- uint64_t reported_size;
- dnbd3_request_t dnbd3_request;
- dnbd3_reply_t dnbd3_reply;
- struct msghdr msg;
- struct kvec iov[2];
- uint16_t rid, proto_version;
- char *name;
- int mlen;
-
- init_msghdr(msg);
+ uint16_t proto_version;
+ dnbd3_dev_dbg_host_cur(dev, "connecting\n");
dev->sock = dnbd3_connect(dev, &dev->cur_server.host);
if (dev->sock == NULL) {
dnbd3_dev_err_host_cur(dev, "%s: Failed\n", __func__);
goto error;
}
-
- // Request filesize
- dnbd3_request.magic = dnbd3_packet_magic;
- dnbd3_request.cmd = CMD_SELECT_IMAGE;
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
- serializer_reset_write(&dev->payload_buffer);
- serializer_put_uint16(&dev->payload_buffer, PROTOCOL_VERSION);
- serializer_put_string(&dev->payload_buffer, dev->imgname);
- serializer_put_uint16(&dev->payload_buffer, dev->rid);
- serializer_put_uint8(&dev->payload_buffer, 0); // is_server = false
- iov[1].iov_base = &dev->payload_buffer;
- dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&dev->payload_buffer);
- fixup_request(dnbd3_request);
- mlen = sizeof(dnbd3_request) + iov[1].iov_len;
- if (kernel_sendmsg(dev->sock, &msg, iov, 2, mlen) != mlen) {
- dnbd3_dev_err_host_cur(dev, "couldn't send CMD_SELECT_IMAGE\n");
- goto error;
- }
- // receive reply header
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(dev->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) !=
- sizeof(dnbd3_reply)) {
- dnbd3_dev_err_host_cur(dev, "received corrupted reply header after CMD_SELECT_IMAGE\n");
+ /* execute the "select image" handshake */
+ if (!dnbd3_execute_handshake(dev, dev->sock, &dev->cur_server.host, &proto_version))
goto error;
- }
- // check reply header
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 3 || dnbd3_reply.size > MAX_PAYLOAD ||
- dnbd3_reply.magic != dnbd3_packet_magic) {
- dnbd3_dev_err_host_cur(
- dev, "received invalid reply to CMD_SELECT_IMAGE, image doesn't exist on server\n");
- goto error;
- }
- // receive reply payload
- iov[0].iov_base = &dev->payload_buffer;
- iov[0].iov_len = dnbd3_reply.size;
- if (kernel_recvmsg(dev->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) {
- dnbd3_dev_err_host_cur(dev, "cold not read CMD_SELECT_IMAGE payload on handshake\n");
- goto error;
- }
- // handle/check reply payload
- serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size);
- proto_version = serializer_get_uint16(&dev->payload_buffer);
+
spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.protocol_version = proto_version;
spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- if (proto_version < MIN_SUPPORTED_SERVER) {
- dnbd3_dev_err_host_cur(dev, "server version is lower than min supported version\n");
- goto error;
- }
- name = serializer_get_string(&dev->payload_buffer);
- if (dev->rid != 0 && strcmp(name, dev->imgname) != 0) {
- dnbd3_dev_err_host_cur(dev, "server offers image '%s', requested '%s'\n", name, dev->imgname);
- goto error;
- }
- if (strlen(dev->imgname) < strlen(name)) {
- dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_ATOMIC);
- if (dev->imgname == NULL) {
- dnbd3_dev_err_host_cur(dev, "reallocating buffer for new image name failed\n");
- goto error;
- }
- }
- strcpy(dev->imgname, name);
- rid = serializer_get_uint16(&dev->payload_buffer);
- if (dev->rid != 0 && dev->rid != rid) {
- dnbd3_dev_err_host_cur(dev, "server provides rid %d, requested was %d\n", (int)rid,
- (int)dev->rid);
- goto error;
- }
- dev->rid = rid;
- reported_size = serializer_get_uint64(&dev->payload_buffer);
- if (reported_size < 4096) {
- dnbd3_dev_err_host_cur(dev, "reported size by server is < 4096\n");
- goto error;
- }
- if (dev->reported_size != 0 && dev->reported_size != reported_size) {
- dnbd3_dev_err_host_cur(dev, "newly connected server reports size %llu, but expected is %llu\n",
- reported_size, dev->reported_size);
- goto error;
- } else if (dev->reported_size == 0) {
- // store image information
- dev->reported_size = reported_size;
- set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */
- dnbd3_dev_dbg_host_cur(dev, "image size: %llu\n", dev->reported_size);
- dev->update_available = 0;
- }
}
- // create required threads
- dev->thread_send = kthread_create(dnbd3_net_send, dev, "%s-send", dev->disk->disk_name);
- if (!IS_ERR(dev->thread_send)) {
- get_task_struct(dev->thread_send);
- wake_up_process(dev->thread_send);
- } else {
- dev_err(dnbd3_device_to_dev(dev), "failed to create send thread\n");
- /* reset error to cleanup thread */
- dev->thread_send = NULL;
+ /* create required threads */
+ if (!spawn_worker_thread(dev, &dev->thread_send, "send", dnbd3_net_send))
goto error;
- }
-
- dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name);
- if (!IS_ERR(dev->thread_receive)) {
- get_task_struct(dev->thread_receive);
- wake_up_process(dev->thread_receive);
- } else {
- dev_err(dnbd3_device_to_dev(dev), "failed to create receive thread\n");
- /* reset error to cleanup thread */
- dev->thread_receive = NULL;
+ if (!spawn_worker_thread(dev, &dev->thread_receive, "receive", dnbd3_net_receive))
goto error;
- }
-
- dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name);
- if (!IS_ERR(dev->thread_discover)) {
- get_task_struct(dev->thread_discover);
- wake_up_process(dev->thread_discover);
- } else {
- dev_err(dnbd3_device_to_dev(dev), "failed to create discover thread\n");
- /* reset error to cleanup thread */
- dev->thread_discover = NULL;
+ if (!spawn_worker_thread(dev, &dev->thread_discover, "discover", dnbd3_net_discover))
goto error;
- }
+ dnbd3_dev_dbg_host_cur(dev, "connection established\n");
dev->panic = 0;
dev->panic_count = 0;
- // Enqueue request to request_queue_send for a fresh list of alt servers
- dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS);
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_add(&req1->queuelist, &dev->request_queue_send);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
-
- wake_up(&dev->process_queue_send);
+ if (req_alt_servers != NULL) {
+ // Enqueue request to request_queue_send for a fresh list of alt servers
+ dnbd3_cmd_to_priv(req_alt_servers, CMD_GET_SERVERS);
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ list_add(&req_alt_servers->queuelist, &dev->request_queue_send);
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ wake_up(&dev->process_queue_send);
+ }
// add heartbeat timer
// Do not goto error after creating the timer - we require that the timer exists
@@ -1135,21 +1142,9 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
return 0;
error:
- if (dev->thread_send) {
- kthread_stop(dev->thread_send);
- put_task_struct(dev->thread_send);
- dev->thread_send = NULL;
- }
- if (dev->thread_receive) {
- kthread_stop(dev->thread_receive);
- put_task_struct(dev->thread_receive);
- dev->thread_receive = NULL;
- }
- if (dev->thread_discover) {
- kthread_stop(dev->thread_discover);
- put_task_struct(dev->thread_discover);
- dev->thread_discover = NULL;
- }
+ stop_worker_thread(dev, &dev->thread_send, "send", 1);
+ stop_worker_thread(dev, &dev->thread_receive, "receive", 1);
+ stop_worker_thread(dev, &dev->thread_discover, "discover", 1);
if (dev->sock) {
sock_release(dev->sock);
dev->sock = NULL;
@@ -1157,16 +1152,14 @@ error:
spin_lock_irqsave(&dev->blk_lock, irqflags);
dev->cur_server.host.ss_family = 0;
spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- kfree(req1);
+ kfree(req_alt_servers);
return -1;
}
int dnbd3_net_disconnect(dnbd3_device_t *dev)
{
- struct task_struct *thread = NULL;
unsigned long irqflags;
- int ret;
dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n");
ASSERT(atomic_read(&dev->connection_lock));
@@ -1180,56 +1173,9 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
}
// kill sending and receiving threads
- if (dev->thread_send) {
- dnbd3_dev_dbg_host_cur(dev, "stop send thread\n");
- thread = dev->thread_send;
- ret = kthread_stop(thread);
- put_task_struct(thread);
- if (ret == -EINTR) {
- /* thread has never been scheduled and run */
- dev_dbg(dnbd3_device_to_dev(dev), "send thread has never run\n");
- } else {
- /* thread has run, check if it has terminated successfully */
- if (ret < 0)
- dev_err(dnbd3_device_to_dev(dev), "send thread was not terminated correctly\n");
- }
- dev->thread_send = NULL;
- }
-
- if (dev->thread_receive) {
- dnbd3_dev_dbg_host_cur(dev, "stop receive thread\n");
- thread = dev->thread_receive;
- ret = kthread_stop(thread);
- put_task_struct(thread);
- if (ret == -EINTR) {
- /* thread has never been scheduled and run */
- dev_dbg(dnbd3_device_to_dev(dev), "receive thread has never run\n");
- } else {
- /* thread has run, check if it has terminated successfully */
- if (ret < 0)
- dev_err(dnbd3_device_to_dev(dev), "receive thread was not terminated correctly\n");
- }
- dev->thread_receive = NULL;
- }
-
- if (dev->thread_discover) {
- dnbd3_dev_dbg_host_cur(dev, "stop discover thread\n");
- thread = dev->thread_discover;
- ret = kthread_stop(thread);
- put_task_struct(thread);
- if (ret == -EINTR) {
- /* thread has never been scheduled and run */
- dev_dbg(dnbd3_device_to_dev(dev), "discover thread has never run\n");
- } else {
- /* thread has run, check if it has terminated successfully */
- if (ret < 0) {
- dev_err(dnbd3_device_to_dev(dev), "discover thread was not terminated correctly (%d)\n",
- ret);
- }
- }
- dev->thread_discover = NULL;
- }
-
+ stop_worker_thread(dev, &dev->thread_send, "send", 0);
+ stop_worker_thread(dev, &dev->thread_receive, "receive", 0);
+ stop_worker_thread(dev, &dev->thread_discover, "discover", 0);
if (dev->sock) {
sock_release(dev->sock);
dev->sock = NULL;
diff --git a/src/kernel/net.h b/src/kernel/net.h
index d46505b..f91334e 100644
--- a/src/kernel/net.h
+++ b/src/kernel/net.h
@@ -24,15 +24,6 @@
#include "dnbd3_main.h"
-#define init_msghdr(h) \
- do { \
- h.msg_name = NULL; \
- h.msg_namelen = 0; \
- h.msg_control = NULL; \
- h.msg_controllen = 0; \
- h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \
- } while (0)
-
int dnbd3_net_connect(dnbd3_device_t *lo);
int dnbd3_net_disconnect(dnbd3_device_t *lo);