Diffstat (limited to 'src/kernel/net.c')
-rw-r--r--  src/kernel/net.c  1929
1 file changed, 968 insertions, 961 deletions
diff --git a/src/kernel/net.c b/src/kernel/net.c
index 9e48b86..5ef4016 100644
--- a/src/kernel/net.c
+++ b/src/kernel/net.c
@@ -1,9 +1,10 @@
+// SPDX-License-Identifier: GPL-2.0
/*
* This file is part of the Distributed Network Block Device 3
*
* Copyright(c) 2011-2012 Johann Latocha <johann@latocha.de>
*
- * This file may be licensed under the terms of of the
+ * This file may be licensed under the terms of the
* GNU General Public License Version 2 (the ``GPL'').
*
* Software distributed under the License is distributed
@@ -18,1106 +19,1112 @@
*
*/
-#include "clientconfig.h"
+#include <dnbd3/config/client.h>
#include "net.h"
#include "blk.h"
-#include "utils.h"
+#include "dnbd3_main.h"
-#include "serialize.h"
+#include <dnbd3/shared/serialize.h>
+
+#include <linux/random.h>
+#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 15, 0)
+#define get_random_u32 prandom_u32
+#endif
#include <linux/time.h>
-#include <linux/signal.h>
+#include <linux/ktime.h>
+#include <linux/tcp.h>
#ifndef MIN
-#define MIN(a,b) ((a) < (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0)
-#define dnbd3_sock_create(af,type,proto,sock) sock_create_kern(&init_net, (af) == HOST_IP4 ? AF_INET : AF_INET6, type, proto, sock)
-#else
-#define dnbd3_sock_create(af,type,proto,sock) sock_create_kern((af) == HOST_IP4 ? AF_INET : AF_INET6, type, proto, sock)
+#ifndef ktime_to_s
+#define ktime_to_s(kt) ktime_divns(kt, NSEC_PER_SEC)
#endif
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 11, 0)
-// cmd_flags and cmd_type are merged into cmd_flags now
-#if REQ_FLAG_BITS > 24
-#error "Fix CMD bitshift"
-#endif
-// Pack into cmd_flags field by shifting CMD_* into unused bits of cmd_flags
-#define dnbd3_cmd_to_priv(req, cmd) (req)->cmd_flags = REQ_OP_DRV_IN | ((cmd) << REQ_FLAG_BITS)
-#define dnbd3_priv_to_cmd(req) ((req)->cmd_flags >> REQ_FLAG_BITS)
-#define dnbd3_req_op(req) req_op(req)
-#define DNBD3_DEV_READ REQ_OP_READ
-#define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN
+#ifdef DEBUG
+#define ASSERT(x) \
+ do { \
+ if (!(x)) { \
+ printk(KERN_EMERG "assertion failed %s: %d: %s\n", __FILE__, __LINE__, #x); \
+ BUG(); \
+ } \
+ } while (0)
#else
-// Old way with type and flags separated
-#define dnbd3_cmd_to_priv(req, cmd) do { \
- (req)->cmd_type = REQ_TYPE_SPECIAL; \
- (req)->cmd_flags = (cmd); \
-} while (0)
-#define dnbd3_priv_to_cmd(req) (req)->cmd_flags
-#define dnbd3_req_op(req) (req)->cmd_type
-#define DNBD3_DEV_READ REQ_TYPE_FS
-#define DNBD3_REQ_OP_SPECIAL REQ_TYPE_SPECIAL
+#define ASSERT(x) \
+ do { \
+ } while (0)
#endif
-/**
- * Some macros for easier debug output. Location in source-code
- * as well as server IP:port info will be printed.
- * The error_* macros include a "goto error;" at the end
- */
-#if 1 // Change to 0 to disable debug messages
-#define debug_print_va_host(_host, _fmt, ...) do { \
- if ((_host).type == HOST_IP4) \
- printk("%s:%d " _fmt " (%s, %pI4:%d)\n", __FILE__, __LINE__, __VA_ARGS__, dev->disk->disk_name, (_host).addr, (int)ntohs((_host).port)); \
- else \
- printk("%s:%d " _fmt " (%s, [%pI6]:%d)\n", __FILE__, __LINE__, __VA_ARGS__, dev->disk->disk_name, (_host).addr, (int)ntohs((_host).port)); \
-} while(0)
-#define debug_error_va_host(_host, _fmt, ...) do { \
- debug_print_va_host(_host, _fmt, __VA_ARGS__); \
- goto error; \
-} while(0)
-#define debug_dev_va(_fmt, ...) debug_print_va_host(dev->cur_server.host, _fmt, __VA_ARGS__)
-#define error_dev_va(_fmt, ...) debug_error_va_host(dev->cur_server.host, _fmt, __VA_ARGS__)
-#define debug_alt_va(_fmt, ...) debug_print_va_host(dev->alt_servers[i].host, _fmt, __VA_ARGS__)
-#define error_alt_va(_fmt, ...) debug_error_va_host(dev->alt_servers[i].host, _fmt, __VA_ARGS__)
-
-#define debug_print_host(_host, txt) do { \
- if ((_host).type == HOST_IP4) \
- printk("%s:%d " txt " (%s, %pI4:%d)\n", __FILE__, __LINE__, dev->disk->disk_name, (_host).addr, (int)ntohs((_host).port)); \
- else \
- printk("%s:%d " txt " (%s, [%pI6]:%d)\n", __FILE__, __LINE__, dev->disk->disk_name, (_host).addr, (int)ntohs((_host).port)); \
-} while(0)
-#define debug_error_host(_host, txt) do { \
- debug_print_host(_host, txt); \
- goto error; \
-} while(0)
-#define debug_dev(txt) debug_print_host(dev->cur_server.host, txt)
-#define error_dev(txt) debug_error_host(dev->cur_server.host, txt)
-#define debug_alt(txt) debug_print_host(dev->alt_servers[i].host, txt)
-#define error_alt(txt) debug_error_host(dev->alt_servers[i].host, txt)
-
-#else // Silent
-#define debug_dev(x) do { } while(0)
-#define error_dev(x) goto error
-#define debug_dev_va(x, ...) do { } while(0)
-#define error_dev_va(x, ...) goto error
-#define debug_alt(x) do { } while(0)
-#define error_alt(x) goto error
-#define debug_alt_va(x, ...) do { } while(0)
-#define error_alt_va(x, ...) goto error
-#endif
+#define dnbd3_dev_dbg_host(dev, host, fmt, ...) \
+ dev_dbg(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__)
+#define dnbd3_dev_info_host(dev, host, fmt, ...) \
+ dev_info(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__)
+#define dnbd3_dev_err_host(dev, host, fmt, ...) \
+ dev_err(dnbd3_device_to_dev(dev), "(%pISpc): " fmt, (host), ##__VA_ARGS__)
-static inline int is_same_server(const dnbd3_server_t * const a, const dnbd3_server_t * const b)
-{
- return (a->host.type == b->host.type) && (a->host.port == b->host.port)
- && (0 == memcmp(a->host.addr, b->host.addr, (a->host.type == HOST_IP4 ? 4 : 16)));
-}
+#define dnbd3_dev_dbg_cur(dev, fmt, ...) \
+ dnbd3_dev_dbg_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__)
+#define dnbd3_dev_info_cur(dev, fmt, ...) \
+ dnbd3_dev_info_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__)
+#define dnbd3_dev_err_cur(dev, fmt, ...) \
+ dnbd3_dev_err_host(dev, &(dev)->cur_server.host, fmt, ##__VA_ARGS__)
-static inline dnbd3_server_t *get_existing_server(const dnbd3_server_entry_t * const newserver,
- dnbd3_device_t * const dev)
-{
- int i;
- for (i = 0; i < NUMBER_SERVERS; ++i)
- {
- if ((newserver->host.type == dev->alt_servers[i].host.type)
- && (newserver->host.port == dev->alt_servers[i].host.port)
- && (0
- == memcmp(newserver->host.addr, dev->alt_servers[i].host.addr, (newserver->host.type == HOST_IP4 ? 4 : 16))))
- {
- return &dev->alt_servers[i];
- break;
- }
- }
- return NULL ;
-}
-
-static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t * const dev)
-{
- int i;
- for (i = 0; i < NUMBER_SERVERS; ++i)
- {
- if (dev->alt_servers[i].host.type == 0)
- return &dev->alt_servers[i];
- }
- for (i = 0; i < NUMBER_SERVERS; ++i)
- {
- if (dev->alt_servers[i].failures > 10)
- return &dev->alt_servers[i];
- }
- return NULL ;
-}
+static bool dnbd3_drain_socket(dnbd3_device_t *dev, struct socket *sock, int bytes);
+static int dnbd3_recv_bytes(struct socket *sock, void *buffer, size_t count);
+static int dnbd3_recv_reply(struct socket *sock, dnbd3_reply_t *reply_hdr);
+static bool dnbd3_send_request(struct socket *sock, u16 cmd, u64 handle, u64 offset, u32 size);
-int dnbd3_net_connect(dnbd3_device_t *dev)
-{
- struct request *req1 = NULL;
- struct timeval timeout;
+static int dnbd3_set_primary_connection(dnbd3_device_t *dev, struct socket *sock,
+ struct sockaddr_storage *addr, u16 protocol_version);
- if (dev->disconnecting) {
- debug_dev("CONNECT: Still disconnecting!!!\n");
- while (dev->disconnecting)
- schedule();
- }
- if (dev->thread_receive != NULL) {
- debug_dev("CONNECT: Still receiving!!!\n");
- while (dev->thread_receive != NULL)
- schedule();
- }
- if (dev->thread_send != NULL) {
- debug_dev("CONNECT: Still sending!!!\n");
- while (dev->thread_send != NULL)
- schedule();
- }
+static int dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr,
+ struct socket **sock_out);
- timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA;
- timeout.tv_usec = 0;
+static bool dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock,
+ struct sockaddr_storage *addr, uint16_t *remote_version, bool copy_image_info);
- // do some checks before connecting
+static bool dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr,
+ struct socket *sock);
- req1 = kmalloc(sizeof(*req1), GFP_ATOMIC );
- if (!req1)
- error_dev("FATAL: Kmalloc(1) failed.");
+static bool dnbd3_send_empty_request(dnbd3_device_t *dev, u16 cmd);
- if (dev->cur_server.host.port == 0 || dev->cur_server.host.type == 0 || dev->imgname == NULL )
- error_dev("FATAL: Host, port or image name not set.");
- if (dev->sock)
- error_dev("ERROR: Already connected.");
-
- if (dev->cur_server.host.type != HOST_IP4 && dev->cur_server.host.type != HOST_IP6)
- error_dev_va("ERROR: Unknown address type %d", (int)dev->cur_server.host.type);
-
- debug_dev("INFO: Connecting...");
-
- if (dev->better_sock == NULL )
- {
- // no established connection yet from discovery thread, start new one
- dnbd3_request_t dnbd3_request;
- dnbd3_reply_t dnbd3_reply;
- struct msghdr msg;
- struct kvec iov[2];
- uint16_t rid;
- char *name;
- int mlen;
- init_msghdr(msg);
-
- if (dnbd3_sock_create(dev->cur_server.host.type, SOCK_STREAM, IPPROTO_TCP, &dev->sock) < 0)
- error_dev("ERROR: Couldn't create socket (v6).");
-
- kernel_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout));
- kernel_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout));
- dev->sock->sk->sk_allocation = GFP_NOIO;
- if (dev->cur_server.host.type == HOST_IP4)
- {
- struct sockaddr_in sin;
- memset(&sin, 0, sizeof(sin));
- sin.sin_family = AF_INET;
- memcpy(&(sin.sin_addr), dev->cur_server.host.addr, 4);
- sin.sin_port = dev->cur_server.host.port;
- if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0)
- error_dev("FATAL: Connection to host failed. (v4)");
- }
- else
- {
- struct sockaddr_in6 sin;
- memset(&sin, 0, sizeof(sin));
- sin.sin6_family = AF_INET6;
- memcpy(&(sin.sin6_addr), dev->cur_server.host.addr, 16);
- sin.sin6_port = dev->cur_server.host.port;
- if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0)
- error_dev("FATAL: Connection to host failed. (v6)");
- }
- // Request filesize
- dnbd3_request.magic = dnbd3_packet_magic;
- dnbd3_request.cmd = CMD_SELECT_IMAGE;
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
- serializer_reset_write(&dev->payload_buffer);
- serializer_put_uint16(&dev->payload_buffer, PROTOCOL_VERSION);
- serializer_put_string(&dev->payload_buffer, dev->imgname);
- serializer_put_uint16(&dev->payload_buffer, dev->rid);
- serializer_put_uint8(&dev->payload_buffer, 0); // is_server = false
- iov[1].iov_base = &dev->payload_buffer;
- dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&dev->payload_buffer);
- fixup_request(dnbd3_request);
- mlen = sizeof(dnbd3_request) + iov[1].iov_len;
- if (kernel_sendmsg(dev->sock, &msg, iov, 2, mlen) != mlen)
- error_dev("ERROR: Couldn't send CMD_SIZE_REQUEST.");
- // receive reply header
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(dev->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply))
- error_dev("FATAL: Received corrupted reply header after CMD_SIZE_REQUEST.");
- // check reply header
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 3 || dnbd3_reply.size > MAX_PAYLOAD
- || dnbd3_reply.magic != dnbd3_packet_magic)
- error_dev("FATAL: Received invalid reply to CMD_SIZE_REQUEST, image doesn't exist on server.");
- // receive reply payload
- iov[0].iov_base = &dev->payload_buffer;
- iov[0].iov_len = dnbd3_reply.size;
- if (kernel_recvmsg(dev->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size)
- error_dev("FATAL: Cold not read CMD_SELECT_IMAGE payload on handshake.");
- // handle/check reply payload
- serializer_reset_read(&dev->payload_buffer, dnbd3_reply.size);
- dev->cur_server.protocol_version = serializer_get_uint16(&dev->payload_buffer);
- if (dev->cur_server.protocol_version < MIN_SUPPORTED_SERVER)
- error_dev("FATAL: Server version is lower than min supported version.");
- name = serializer_get_string(&dev->payload_buffer);
- if (dev->rid != 0 && strcmp(name, dev->imgname) != 0)
- error_dev_va("FATAL: Server offers image '%s', requested '%s'", name, dev->imgname);
- if (strlen(dev->imgname) < strlen(name))
- {
- dev->imgname = krealloc(dev->imgname, strlen(name) + 1, GFP_ATOMIC );
- if (dev->imgname == NULL )
- error_dev("FATAL: Reallocating buffer for new image name failed");
- }
- strcpy(dev->imgname, name);
- rid = serializer_get_uint16(&dev->payload_buffer);
- if (dev->rid != 0 && dev->rid != rid)
- error_dev_va("FATAL: Server provides rid %d, requested was %d.", (int)rid, (int)dev->rid);
- dev->rid = rid;
- dev->reported_size = serializer_get_uint64(&dev->payload_buffer);
- if (dev->reported_size < 4096)
- error_dev("ERROR: Reported size by server is < 4096");
- // store image information
- set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */
- debug_dev_va("INFO: Filesize: %llu.", dev->reported_size);
- dev->update_available = 0;
- }
- else // Switching server, connection is already established and size request was executed
- {
- debug_dev("INFO: On-the-fly server change.");
- dev->sock = dev->better_sock;
- dev->better_sock = NULL;
- kernel_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout));
- kernel_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout));
- }
+static void dnbd3_start_discover(dnbd3_device_t *dev, bool panic);
- dev->panic = 0;
- dev->panic_count = 0;
+static void dnbd3_discover(dnbd3_device_t *dev);
- // Enqueue request to request_queue_send for a fresh list of alt servers
- dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS);
- list_add(&req1->queuelist, &dev->request_queue_send);
+static void dnbd3_internal_discover(dnbd3_device_t *dev);
- // create required threads
- dev->thread_send = kthread_create(dnbd3_net_send, dev, dev->disk->disk_name);
- dev->thread_receive = kthread_create(dnbd3_net_receive, dev, dev->disk->disk_name);
- dev->thread_discover = kthread_create(dnbd3_net_discover, dev, dev->disk->disk_name);
- // start them up
- wake_up_process(dev->thread_send);
- wake_up_process(dev->thread_receive);
- wake_up_process(dev->thread_discover);
+static void set_socket_timeout(struct socket *sock, bool set_send, int timeout_ms);
- wake_up(&dev->process_queue_send);
+// Use as write-only dump, don't care about race conditions etc.
+static u8 __garbage_mem[PAGE_SIZE];
- // add heartbeat timer
- dev->heartbeat_count = 0;
+/**
+ * Delayed work triggering sending of keepalive packet.
+ */
+static void dnbd3_keepalive_workfn(struct work_struct *work)
+{
+ unsigned long irqflags;
+ dnbd3_device_t *dev = container_of(work, dnbd3_device_t, keepalive_work.work);
-// init_timer_key changed from kernel version 4.14 to 4.15, see and compare to 4.15:
-// https://elixir.bootlin.com/linux/v4.14.32/source/include/linux/timer.h#L98
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0)
- timer_setup(&dev->hb_timer, dnbd3_net_heartbeat, 0);
-#else
- // Old timer setup
- init_timer(&dev->hb_timer);
- dev->hb_timer.data = (unsigned long)dev;
- dev->hb_timer.function = dnbd3_net_heartbeat;
-#endif
- dev->hb_timer.expires = jiffies + HZ;
- add_timer(&dev->hb_timer);
- return 0;
- error: ;
- if (dev->sock)
- {
- sock_release(dev->sock);
- dev->sock = NULL;
+ dnbd3_send_empty_request(dev, CMD_KEEPALIVE);
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ if (device_active(dev)) {
+ mod_delayed_work(system_freezable_power_efficient_wq,
+ &dev->keepalive_work, KEEPALIVE_INTERVAL * HZ);
}
- dev->cur_server.host.type = 0;
- dev->cur_server.host.port = 0;
- if (req1)
- kfree(req1);
- return -1;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
}
-int dnbd3_net_disconnect(dnbd3_device_t *dev)
+/**
+ * Delayed work triggering discovery (alt server check)
+ */
+static void dnbd3_discover_workfn(struct work_struct *work)
{
- if (dev->disconnecting)
- return 0;
-
- if (dev->cur_server.host.port)
- debug_dev("INFO: Disconnecting device.");
-
- dev->disconnecting = 1;
-
- // clear heartbeat timer
- del_timer(&dev->hb_timer);
-
- dev->discover = 0;
-
- if (dev->sock)
- kernel_sock_shutdown(dev->sock, SHUT_RDWR);
-
- // kill sending and receiving threads
- if (dev->thread_send)
- {
- kthread_stop(dev->thread_send);
- }
+ dnbd3_device_t *dev = container_of(work, dnbd3_device_t, discover_work.work);
- if (dev->thread_receive)
- {
- kthread_stop(dev->thread_receive);
- }
+ dnbd3_discover(dev);
+}
- if (dev->thread_discover)
- {
- kthread_stop(dev->thread_discover);
- dev->thread_discover = NULL;
- }
+/**
+ * For manually triggering an immediate discovery
+ */
+static void dnbd3_start_discover(dnbd3_device_t *dev, bool panic)
+{
+ unsigned long irqflags;
- // clear socket
- if (dev->sock)
- {
- sock_release(dev->sock);
- dev->sock = NULL;
+ if (!device_active(dev))
+ return;
+ if (panic && dnbd3_flag_get(dev->connection_lock)) {
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ if (!dev->panic) {
+ // Panic freshly turned on
+ dev->panic = true;
+ dev->discover_interval = TIMER_INTERVAL_PROBE_PANIC;
+ }
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ dnbd3_flag_reset(dev->connection_lock);
}
- dev->cur_server.host.type = 0;
- dev->cur_server.host.port = 0;
-
- dev->disconnecting = 0;
-
- return 0;
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ mod_delayed_work(system_freezable_power_efficient_wq,
+ &dev->discover_work, 1);
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
}
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 15, 0)
-void dnbd3_net_heartbeat(struct timer_list *arg)
-{
- dnbd3_device_t *dev = (dnbd3_device_t *)container_of(arg, dnbd3_device_t, hb_timer);
-#else
-void dnbd3_net_heartbeat(unsigned long arg)
+/**
+ * Wrapper for the actual discover function below. Check run conditions
+ * here and re-schedule delayed task here.
+ */
+static void dnbd3_discover(dnbd3_device_t *dev)
{
- dnbd3_device_t *dev = (dnbd3_device_t *)arg;
-#endif
- // Because different events need different intervals, the timer is called once a second.
- // Other intervals can be derived using dev->heartbeat_count.
-#define timeout_seconds(x) (dev->heartbeat_count % (x) == 0)
-
- if (!dev->panic)
- {
- if (timeout_seconds(TIMER_INTERVAL_KEEPALIVE_PACKET))
- {
- struct request *req = kmalloc(sizeof(struct request), GFP_ATOMIC );
- // send keepalive
- if (req)
- {
- dnbd3_cmd_to_priv(req, CMD_KEEPALIVE);
- list_add_tail(&req->queuelist, &dev->request_queue_send);
- wake_up(&dev->process_queue_send);
- }
- else
- {
- debug_dev("ERROR: Couldn't create keepalive request.");
- }
- }
- if ((dev->heartbeat_count > STARTUP_MODE_DURATION && timeout_seconds(TIMER_INTERVAL_PROBE_NORMAL))
- || (dev->heartbeat_count <= STARTUP_MODE_DURATION && timeout_seconds(TIMER_INTERVAL_PROBE_STARTUP)))
- {
- // Normal discovery
- dev->discover = 1;
- wake_up(&dev->process_queue_discover);
+ unsigned long irqflags;
+
+ if (!device_active(dev) || dnbd3_flag_taken(dev->connection_lock))
+ return; // device not active anymore, or just about to switch
+ if (!dnbd3_flag_get(dev->discover_running))
+ return; // Already busy
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ cancel_delayed_work(&dev->discover_work);
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ dnbd3_internal_discover(dev);
+ dev->discover_count++;
+ // Re-queueing logic
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ if (device_active(dev)) {
+ mod_delayed_work(system_freezable_power_efficient_wq,
+ &dev->discover_work, dev->discover_interval * HZ);
+ if (dev->discover_interval < TIMER_INTERVAL_PROBE_MAX
+ && dev->discover_count > DISCOVER_STARTUP_PHASE_COUNT) {
+ dev->discover_interval += 2;
}
}
- else if (timeout_seconds(TIMER_INTERVAL_PROBE_PANIC))
- {
- // Panic discovery
- dev->discover = 1;
- wake_up(&dev->process_queue_discover);
- }
-
- dev->hb_timer.expires = jiffies + HZ;
-
- ++dev->heartbeat_count;
- add_timer(&dev->hb_timer);
-#undef timeout_seconds
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ dnbd3_flag_reset(dev->discover_running);
}
-int dnbd3_net_discover(void *data)
+/**
+ * Discovery. Probe all (or some) known alt servers,
+ * and initiate connection switch if appropriate
+ */
+static void dnbd3_internal_discover(dnbd3_device_t *dev)
{
- dnbd3_device_t *dev = data;
- struct sockaddr_in sin4;
- struct sockaddr_in6 sin6;
struct socket *sock, *best_sock = NULL;
+ dnbd3_alt_server_t *alt;
+ struct sockaddr_storage host_compare, best_server;
+ uint16_t remote_version;
+ ktime_t start, end;
+ unsigned long rtt = 0, best_rtt = 0;
+ int i, j, k, isize, fails, rtt_threshold;
+ int do_change = 0;
+ u8 check_order[NUMBER_SERVERS];
+ const bool ready = dev->discover_count > DISCOVER_STARTUP_PHASE_COUNT;
+ const u32 turn = dev->discover_count % DISCOVER_HISTORY_SIZE;
+
+ // Shuffle alt_servers
+ for (i = 0; i < NUMBER_SERVERS; ++i)
+ check_order[i] = i;
- dnbd3_request_t dnbd3_request;
- dnbd3_reply_t dnbd3_reply;
- dnbd3_server_t *alt_server;
- struct msghdr msg;
- struct kvec iov[2];
-
- char *buf, *name;
- serialized_buffer_t *payload;
- uint64_t filesize;
- uint16_t rid;
-
- struct timeval start, end;
- unsigned long rtt, best_rtt = 0;
- unsigned long irqflags;
- int i, j, isize, best_server, current_server;
- int turn = 0;
- int ready = 0, do_change = 0;
- char check_order[NUMBER_SERVERS];
- int mlen;
-
- struct request *last_request = (struct request *)123, *cur_request = (struct request *)456;
-
- struct timeval timeout;
- timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DISCOVERY;
- timeout.tv_usec = 0;
-
- memset(&sin4, 0, sizeof(sin4));
- memset(&sin6, 0, sizeof(sin6));
-
- init_msghdr(msg);
+ for (i = 0; i < NUMBER_SERVERS; ++i) {
+ j = get_random_u32() % NUMBER_SERVERS;
+ if (j != i) {
+ int tmp = check_order[i];
- buf = kmalloc(4096, GFP_KERNEL);
- if (!buf)
- {
- debug_dev("FATAL: Kmalloc failed (discover)");
- return -1;
+ check_order[i] = check_order[j];
+ check_order[j] = tmp;
+ }
}
- payload = (serialized_buffer_t *)buf; // Reuse this buffer to save kernel mem
- dnbd3_request.magic = dnbd3_packet_magic;
+ best_server.ss_family = 0;
+ best_rtt = RTT_UNREACHABLE;
- for (i = 0; i < NUMBER_SERVERS; ++i) {
- check_order[i] = i;
- }
-
- for (;;)
- {
- wait_event_interruptible(dev->process_queue_discover,
- kthread_should_stop() || dev->discover || dev->thread_discover == NULL);
+ if (!ready || dev->panic)
+ isize = NUMBER_SERVERS;
+ else
+ isize = 3;
- if (kthread_should_stop() || dev->imgname == NULL || dev->thread_discover == NULL )
+ for (j = 0; j < NUMBER_SERVERS; ++j) {
+ if (!device_active(dev))
break;
+ i = check_order[j];
+ mutex_lock(&dev->alt_servers_lock);
+ host_compare = dev->alt_servers[i].host;
+ fails = dev->alt_servers[i].failures;
+ mutex_unlock(&dev->alt_servers_lock);
+ if (host_compare.ss_family == 0)
+ continue; // Empty slot
+ // Reduced probability for hosts that have been unreachable
+ if (!dev->panic && fails > 50 && (get_random_u32() % 4) != 0)
+ continue; // If not in panic mode, skip server if it failed too many times
+ if (isize-- <= 0 && !is_same_server(&dev->cur_server.host, &host_compare))
+ continue; // Only test isize servers plus current server
+
+ // Initialize socket and connect
+ sock = NULL;
+ if (dnbd3_connect(dev, &host_compare, &sock) != 0)
+ goto error;
- if (!dev->discover)
- continue;
- dev->discover = 0;
-
- if (dev->reported_size < 4096)
- continue;
-
- // Check if the list of alt servers needs to be updated and do so if necessary
- if (dev->new_servers_num)
- {
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- for (i = 0; i < dev->new_servers_num; ++i)
- {
- if (dev->new_servers[i].host.type != HOST_IP4 && dev->new_servers[i].host.type != HOST_IP6) // Invalid entry?
- continue;
- alt_server = get_existing_server(&dev->new_servers[i], dev);
- if (alt_server != NULL ) // Server already known
- {
- if (dev->new_servers[i].failures == 1)
- {
- // REMOVE request
- if (alt_server->host.type == HOST_IP4)
- debug_dev_va("Removing alt server %pI4", alt_server->host.addr);
- else
- debug_dev_va("Removing alt server %pI6", alt_server->host.addr);
- alt_server->host.type = 0;
- continue;
- }
- // ADD, so just reset fail counter
- alt_server->failures = 0;
- continue;
- }
- if (dev->new_servers[i].failures == 1) // REMOVE, but server is not in list anyways
- continue;
- alt_server = get_free_alt_server(dev);
- if (alt_server == NULL ) // All NUMBER_SERVERS slots are taken, ignore entry
- continue;
- // Add new server entry
- alt_server->host = dev->new_servers[i].host;
- if (alt_server->host.type == HOST_IP4)
- debug_dev_va("Adding alt server %pI4", alt_server->host.addr);
- else
- debug_dev_va("Adding alt server %pI6", alt_server->host.addr);
- alt_server->rtts[0] = alt_server->rtts[1] = alt_server->rtts[2] = alt_server->rtts[3] = RTT_UNREACHABLE;
- alt_server->protocol_version = 0;
- alt_server->failures = 0;
- }
- dev->new_servers_num = 0;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- }
+ remote_version = 0;
+ if (!dnbd3_execute_handshake(dev, sock, &host_compare, &remote_version, false))
+ goto error;
- current_server = best_server = -1;
- best_rtt = 0xFFFFFFFul;
- if (dev->heartbeat_count < STARTUP_MODE_DURATION || dev->panic)
- {
- isize = NUMBER_SERVERS;
- }
- else
- {
- isize = 3;
- }
- if (NUMBER_SERVERS > isize) {
- for (i = 0; i < isize; ++i) {
- j = ((start.tv_sec >> i) ^ (start.tv_usec >> j)) % NUMBER_SERVERS;
- if (j != i) {
- mlen = check_order[i];
- check_order[i] = check_order[j];
- check_order[j] = mlen;
- }
+ // panic mode, take first responding server
+ if (dev->panic) {
+ dnbd3_dev_info_host(dev, &host_compare, "panic mode, changing to new server\n");
+ if (!dnbd3_flag_get(dev->connection_lock)) {
+ dnbd3_dev_info_host(dev, &host_compare, "...raced, ignoring\n");
+ } else {
+ // Check global flag, a connect might have been in progress
+ if (best_sock != NULL)
+ sock_release(best_sock);
+ set_socket_timeout(sock, false, SOCKET_TIMEOUT_RECV * 1000 + 1000);
+ if (dnbd3_set_primary_connection(dev, sock, &host_compare, remote_version) != 0)
+ sock_release(sock);
+ dnbd3_flag_reset(dev->connection_lock);
+ return;
}
}
- for (j = 0; j < NUMBER_SERVERS; ++j)
- {
- i = check_order[j];
- if (dev->alt_servers[i].host.type == 0) // Empty slot
- continue;
- if (!dev->panic && dev->alt_servers[i].failures > 50 && (start.tv_usec & 7) != 0) // If not in panic mode, skip server if it failed too many times
- continue;
- if (isize-- <= 0 && !is_same_server(&dev->cur_server, &dev->alt_servers[i]))
- continue;
-
- // Initialize socket and connect
- if (dnbd3_sock_create(dev->alt_servers[i].host.type, SOCK_STREAM, IPPROTO_TCP, &sock) < 0)
- {
- debug_alt("ERROR: Couldn't create socket (discover).");
- sock = NULL;
- continue;
- }
- kernel_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout));
- kernel_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout));
- sock->sk->sk_allocation = GFP_NOIO;
- if (dev->alt_servers[i].host.type == HOST_IP4)
- {
- sin4.sin_family = AF_INET;
- memcpy(&sin4.sin_addr, dev->alt_servers[i].host.addr, 4);
- sin4.sin_port = dev->alt_servers[i].host.port;
- if (kernel_connect(sock, (struct sockaddr *)&sin4, sizeof(sin4), 0) < 0)
- goto error;
- }
- else
- {
- sin6.sin6_family = AF_INET6;
- memcpy(&sin6.sin6_addr, dev->alt_servers[i].host.addr, 16);
- sin6.sin6_port = dev->alt_servers[i].host.port;
- if (kernel_connect(sock, (struct sockaddr *)&sin6, sizeof(sin6), 0) < 0)
- goto error;
- }
+ // actual rtt measurement is just the first block requests and reply
+ start = ktime_get_real();
+ if (!dnbd3_request_test_block(dev, &host_compare, sock))
+ goto error;
+ end = ktime_get_real();
- // Request filesize
- dnbd3_request.cmd = CMD_SELECT_IMAGE;
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
- serializer_reset_write(payload);
- serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version
- serializer_put_string(payload, dev->imgname); // image name
- serializer_put_uint16(payload, dev->rid); // revision id
- serializer_put_uint8(payload, 0); // are we a server? (no!)
- iov[1].iov_base = payload;
- dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(payload);
- fixup_request(dnbd3_request);
- mlen = iov[1].iov_len + sizeof(dnbd3_request);
- if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen)
- error_alt("ERROR: Requesting image size failed.");
-
- // receive net reply
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply))
- error_alt("ERROR: Receiving image size packet (header) failed (discover).");
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.magic != dnbd3_packet_magic || dnbd3_reply.cmd != CMD_SELECT_IMAGE || dnbd3_reply.size < 4)
- error_alt("ERROR: Content of image size packet (header) mismatched (discover).");
-
- // receive data
- iov[0].iov_base = payload;
- iov[0].iov_len = dnbd3_reply.size;
- if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size)
- error_alt("ERROR: Receiving image size packet (payload) failed (discover).");
- serializer_reset_read(payload, dnbd3_reply.size);
-
- dev->alt_servers[i].protocol_version = serializer_get_uint16(payload);
- if (dev->alt_servers[i].protocol_version < MIN_SUPPORTED_SERVER)
- error_alt_va("ERROR: Server version too old (client: %d, server: %d, min supported: %d).",
- (int)PROTOCOL_VERSION, (int)dev->alt_servers[i].protocol_version, (int)MIN_SUPPORTED_SERVER);
-
- name = serializer_get_string(payload);
- if (name == NULL )
- error_alt("ERROR: Server did not supply an image name (discover).");
-
- if (strcmp(name, dev->imgname) != 0)
- error_alt_va("ERROR: Image name does not match requested one (client: '%s', server: '%s') (discover).",
- dev->imgname, name);
-
- rid = serializer_get_uint16(payload);
- if (rid != dev->rid)
- error_alt_va("ERROR: Server supplied wrong rid (client: '%d', server: '%d') (discover).",
- (int)dev->rid, (int)rid);
-
- filesize = serializer_get_uint64(payload);
- if (filesize != dev->reported_size)
- error_alt_va("ERROR: Reported image size of %llu does not match expected value %llu.(discover).",
- (unsigned long long)filesize, (unsigned long long)dev->reported_size);
-
- // panic mode, take first responding server
- if (dev->panic)
- {
- dev->panic = 0;
- debug_alt("WARN: Panic mode, changing server:");
- if (best_sock != NULL )
- sock_release(best_sock);
- dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect();
- kfree(buf);
- dev->thread_discover = NULL;
- dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server));
- dnbd3_net_connect(dev);
- return 0;
- }
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i].host, &host_compare)) {
+ dev->alt_servers[i].protocol_version = remote_version;
+ dev->alt_servers[i].rtts[turn] =
+ (unsigned long)ktime_us_delta(end, start);
- // Request block
- dnbd3_request.cmd = CMD_GET_BLOCK;
- // Do *NOT* pick a random block as it has proven to cause severe
- // cache thrashing on the server
- dnbd3_request.offset = 0;
- dnbd3_request.size = RTT_BLOCK_SIZE;
- fixup_request(dnbd3_request);
- iov[0].iov_base = &dnbd3_request;
- iov[0].iov_len = sizeof(dnbd3_request);
-
- // start rtt measurement
- do_gettimeofday(&start);
-
- if (kernel_sendmsg(sock, &msg, iov, 1, sizeof(dnbd3_request)) <= 0)
- error_alt("ERROR: Requesting test block failed (discover).");
-
- // receive net reply
- iov[0].iov_base = &dnbd3_reply;
- iov[0].iov_len = sizeof(dnbd3_reply);
- if (kernel_recvmsg(sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply))
- error_alt("ERROR: Receiving test block header packet failed (discover).");
- fixup_reply(dnbd3_reply);
- if (dnbd3_reply.magic
- != dnbd3_packet_magic|| dnbd3_reply.cmd != CMD_GET_BLOCK || dnbd3_reply.size != RTT_BLOCK_SIZE)
- error_alt_va("ERROR: Unexpected reply to block request: cmd=%d, size=%d (discover).",
- (int)dnbd3_reply.cmd, (int)dnbd3_reply.size);
-
- // receive data
- iov[0].iov_base = buf;
- iov[0].iov_len = RTT_BLOCK_SIZE;
- if (kernel_recvmsg(sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != RTT_BLOCK_SIZE)
- error_alt("ERROR: Receiving test block payload failed (discover).");
-
- do_gettimeofday(&end); // end rtt measurement
-
- dev->alt_servers[i].rtts[turn] = (unsigned long)((end.tv_sec - start.tv_sec) * 1000000ull
- + (end.tv_usec - start.tv_usec));
-
- rtt = (dev->alt_servers[i].rtts[0] + dev->alt_servers[i].rtts[1] + dev->alt_servers[i].rtts[2]
- + dev->alt_servers[i].rtts[3]) / 4;
-
- if (best_rtt > rtt)
- {
- // This one is better, keep socket open in case we switch
- best_rtt = rtt;
- best_server = i;
- if (best_sock != NULL )
- sock_release(best_sock);
- best_sock = sock;
- sock = NULL;
- }
- else
- {
- // Not better, discard connection
- sock_release(sock);
- sock = NULL;
- }
+ rtt = 0;
- // update cur servers rtt
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i]))
- {
- dev->cur_rtt = rtt;
- current_server = i;
- }
+ for (k = 0; k < DISCOVER_HISTORY_SIZE; ++k)
+ rtt += dev->alt_servers[i].rtts[k];
+ rtt /= DISCOVER_HISTORY_SIZE;
dev->alt_servers[i].failures = 0;
+ if (dev->alt_servers[i].best_count > 1)
+ dev->alt_servers[i].best_count -= 2;
+ }
+ mutex_unlock(&dev->alt_servers_lock);
- continue;
-
- error: ;
- ++dev->alt_servers[i].failures;
+ if (best_rtt > rtt) {
+ // This one is better, keep socket open in case we switch
+ best_rtt = rtt;
+ best_server = host_compare;
+ if (best_sock != NULL)
+ sock_release(best_sock);
+ best_sock = sock;
+ sock = NULL;
+ } else {
+ // Not better, discard connection
sock_release(sock);
sock = NULL;
- dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
- if (is_same_server(&dev->cur_server, &dev->alt_servers[i]))
- {
- dev->cur_rtt = RTT_UNREACHABLE;
- current_server = i;
- }
- continue;
}
- if (dev->panic)
- {
- // After 21 retries, bail out by reporting errors to block layer
- if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count < 255 && ++dev->panic_count == PROBE_COUNT_TIMEOUT + 1)
- dnbd3_blk_fail_all_requests(dev);
- }
+ // update cur servers rtt
+ if (is_same_server(&dev->cur_server.host, &host_compare))
+ dev->cur_server.rtt = rtt;
- if (best_server == -1 || kthread_should_stop() || dev->thread_discover == NULL ) // No alt server could be reached at all or thread should stop
- {
- if (best_sock != NULL ) // Should never happen actually
- {
- sock_release(best_sock);
- best_sock = NULL;
- }
- continue;
- }
+ continue;
- do_change = ready && best_server != current_server && (start.tv_usec & 3) != 0
- && RTT_THRESHOLD_FACTOR(dev->cur_rtt) > best_rtt + 1500;
-
- if (ready && !do_change) {
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- if (!list_empty(&dev->request_queue_send))
- {
- cur_request = list_entry(dev->request_queue_send.next, struct request, queuelist);
- do_change = (cur_request == last_request);
- if (do_change)
- printk("WARNING: Hung request on %s\n", dev->disk->disk_name);
- }
- else
- {
- cur_request = (struct request *)123;
- }
- last_request = cur_request;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+error:
+ if (sock != NULL) {
+ sock_release(sock);
+ sock = NULL;
}
-
- // take server with lowest rtt
- if (do_change)
- {
- printk("INFO: Server %d on %s is faster (%lluµs vs. %lluµs)\n", best_server, dev->disk->disk_name,
- (unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt);
- kfree(buf);
- dev->better_sock = best_sock; // Take shortcut by continuing to use open connection
- dev->thread_discover = NULL;
- dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server));
- dev->cur_rtt = best_rtt;
- dnbd3_net_connect(dev);
- return 0;
+ mutex_lock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->alt_servers[i].host, &host_compare)) {
+ if (remote_version)
+ dev->alt_servers[i].protocol_version = remote_version;
+ ++dev->alt_servers[i].failures;
+ dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
+ if (dev->alt_servers[i].best_count > 2)
+ dev->alt_servers[i].best_count -= 3;
}
-
- // Clean up connection that was held open for quicker server switch
- if (best_sock != NULL )
- {
- sock_release(best_sock);
- best_sock = NULL;
+ mutex_unlock(&dev->alt_servers_lock);
+ if (is_same_server(&dev->cur_server.host, &host_compare))
+ dev->cur_server.rtt = RTT_UNREACHABLE;
+ } // END - for loop over alt_servers
+
+ if (best_server.ss_family == 0) {
+ // No alt server could be reached
+ ASSERT(!best_sock);
+ if (dev->panic) {
+ if (dev->panic_count < 255)
+ dev->panic_count++;
+ // If probe timeout is set, report error to block layer
+ if (PROBE_COUNT_TIMEOUT > 0 && dev->panic_count == PROBE_COUNT_TIMEOUT + 1)
+ dnbd3_blk_fail_all_requests(dev);
}
+ return;
+ }
- if (!ready || (start.tv_usec & 15) != 0)
- turn = (turn + 1) % 4;
- if (turn == 2) // Set ready when we only have 2 of 4 measurements for quicker load balancing
- ready = 1;
-
+ // If best server was repeatedly measured best, lower the switching threshold more
+ mutex_lock(&dev->alt_servers_lock);
+ alt = get_existing_alt_from_addr(&best_server, dev);
+ if (alt != NULL) {
+ if (alt->best_count < 178)
+ alt->best_count += 3;
+ rtt_threshold = 1800 - (alt->best_count * 10);
+ remote_version = alt->protocol_version;
+ } else {
+ rtt_threshold = 1800;
+ remote_version = 0;
}
- kfree(buf);
- return 0;
+ mutex_unlock(&dev->alt_servers_lock);
+
+ do_change = ready && !is_same_server(&best_server, &dev->cur_server.host)
+ && RTT_THRESHOLD_FACTOR(dev->cur_server.rtt) > best_rtt + rtt_threshold;
+
+ // take server with lowest rtt
+ // if a (dis)connect is already in progress, we do nothing, this is not panic mode
+ if (do_change && device_active(dev) && dnbd3_flag_get(dev->connection_lock)) {
+ dnbd3_dev_info_cur(dev, "server %pISpc is faster (%lluµs vs. %lluµs)\n",
+ &best_server,
+ (unsigned long long)best_rtt, (unsigned long long)dev->cur_server.rtt);
+ set_socket_timeout(best_sock, false, // recv
+ MAX(best_rtt / 1000, SOCKET_TIMEOUT_RECV * 1000) + 500);
+ set_socket_timeout(best_sock, true, // send
+ MAX(best_rtt / 1000, SOCKET_TIMEOUT_SEND * 1000) + 500);
+ if (dnbd3_set_primary_connection(dev, best_sock, &best_server, remote_version) != 0)
+ sock_release(best_sock);
+ dnbd3_flag_reset(dev->connection_lock);
+ return;
+ }
+
+ // Clean up connection that was held open for quicker server switch
+ if (best_sock != NULL)
+ sock_release(best_sock);
}
-int dnbd3_net_send(void *data)
+/**
+ * Worker for sending pending requests. This will be triggered whenever
+ * we get a new request from the block layer. The worker will then
+ * work through all the requests in the send queue, request them from
+ * the server, and return again.
+ */
+static void dnbd3_send_workfn(struct work_struct *work)
{
- dnbd3_device_t *dev = data;
- struct request *blk_request, *tmp_request;
-
- dnbd3_request_t dnbd3_request;
- struct msghdr msg;
- struct kvec iov;
-
+ dnbd3_device_t *dev = container_of(work, dnbd3_device_t, send_work);
+ struct request *blk_request;
+ struct dnbd3_cmd *cmd;
unsigned long irqflags;
- init_msghdr(msg);
-
- dnbd3_request.magic = dnbd3_packet_magic;
-
- set_user_nice(current, -20);
-
- // move already sent requests to request_queue_send again
- while (!list_empty(&dev->request_queue_receive))
- {
- printk("WARN: Request queue was not empty on %s\n", dev->disk->disk_name);
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_for_each_entry_safe(blk_request, tmp_request, &dev->request_queue_receive, queuelist)
- {
- list_del_init(&blk_request->queuelist);
- list_add(&blk_request->queuelist, &dev->request_queue_send);
- }
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- }
-
- for (;;)
- {
- wait_event_interruptible(dev->process_queue_send, kthread_should_stop() || !list_empty(&dev->request_queue_send));
-
- if (kthread_should_stop())
+ mutex_lock(&dev->send_mutex);
+ while (dev->sock && device_active(dev)) {
+ // extract next block request
+ spin_lock_irqsave(&dev->send_queue_lock, irqflags);
+ if (list_empty(&dev->send_queue)) {
+ spin_unlock_irqrestore(&dev->send_queue_lock, irqflags);
break;
-
- // extract block request
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- if (list_empty(&dev->request_queue_send))
- {
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- continue;
}
- blk_request = list_entry(dev->request_queue_send.next, struct request, queuelist);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- // what to do?
- switch (dnbd3_req_op(blk_request))
- {
- case DNBD3_DEV_READ:
- dnbd3_request.cmd = CMD_GET_BLOCK;
- dnbd3_request.offset = blk_rq_pos(blk_request) << 9; // *512
- dnbd3_request.size = blk_rq_bytes(blk_request); // bytes left to complete entire request
- // enqueue request to request_queue_receive
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_del_init(&blk_request->queuelist);
- list_add_tail(&blk_request->queuelist, &dev->request_queue_receive);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- break;
- case DNBD3_REQ_OP_SPECIAL:
- dnbd3_request.cmd = dnbd3_priv_to_cmd(blk_request);
- dnbd3_request.size = 0;
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_del_init(&blk_request->queuelist);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ blk_request = list_entry(dev->send_queue.next, struct request, queuelist);
+ list_del_init(&blk_request->queuelist);
+ spin_unlock_irqrestore(&dev->send_queue_lock, irqflags);
+ // append to receive queue
+ spin_lock_irqsave(&dev->recv_queue_lock, irqflags);
+ list_add_tail(&blk_request->queuelist, &dev->recv_queue);
+ spin_unlock_irqrestore(&dev->recv_queue_lock, irqflags);
+
+ cmd = blk_mq_rq_to_pdu(blk_request);
+ if (!dnbd3_send_request(dev->sock, CMD_GET_BLOCK, cmd->handle,
+ blk_rq_pos(blk_request) << 9 /* sectors */, blk_rq_bytes(blk_request))) {
+ if (!dnbd3_flag_taken(dev->connection_lock)) {
+ dnbd3_dev_err_cur(dev, "connection to server lost (send)\n");
+ dnbd3_start_discover(dev, true);
+ }
break;
-
- default:
- printk("ERROR: Unknown command (send %u %u)\n", (int)blk_request->cmd_flags, (int)dnbd3_req_op(blk_request));
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_del_init(&blk_request->queuelist);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- continue;
}
-
- // send net request
- dnbd3_request.handle = (uint64_t)(uintptr_t)blk_request; // Double cast to prevent warning on 32bit
- fixup_request(dnbd3_request);
- iov.iov_base = &dnbd3_request;
- iov.iov_len = sizeof(dnbd3_request);
- if (kernel_sendmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request))
- {
- debug_dev("ERROR: Connection to server lost (send)");
- goto error;
- }
- wake_up(&dev->process_queue_receive);
}
-
- dev->thread_send = NULL;
- return 0;
-
- error: ;
- if (dev->sock)
- kernel_sock_shutdown(dev->sock, SHUT_RDWR);
- if (!dev->disconnecting)
- {
- dev->panic = 1;
- dev->discover = 1;
- wake_up(&dev->process_queue_discover);
- }
- dev->thread_send = NULL;
- return -1;
+ mutex_unlock(&dev->send_mutex);
}
-int dnbd3_net_receive(void *data)
+/**
+ * The receive workfn stays active for as long as the connection to a server
+ * lasts, i.e. it only gets restarted when we switch to a new server.
+ */
+static void dnbd3_recv_workfn(struct work_struct *work)
{
- dnbd3_device_t *dev = data;
- struct request *blk_request, *tmp_request, *received_request;
-
- dnbd3_reply_t dnbd3_reply;
- struct msghdr msg;
- struct kvec iov;
+ dnbd3_device_t *dev = container_of(work, dnbd3_device_t, recv_work);
+ struct request *blk_request;
+ struct request *rq_iter;
+ struct dnbd3_cmd *cmd;
+ dnbd3_reply_t reply_hdr;
struct req_iterator iter;
struct bio_vec bvec_inst;
struct bio_vec *bvec = &bvec_inst;
+ struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL };
+ struct kvec iov;
void *kaddr;
unsigned long irqflags;
- sigset_t blocked, oldset;
uint16_t rid;
- unsigned long int recv_timeout = jiffies;
-
- int count, remaining, ret;
-
- init_msghdr(msg);
- set_user_nice(current, -20);
+ int remaining;
+ int ret;
- while (!kthread_should_stop())
- {
+ mutex_lock(&dev->recv_mutex);
+ while (dev->sock) {
// receive net reply
- iov.iov_base = &dnbd3_reply;
- iov.iov_len = sizeof(dnbd3_reply);
- ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags);
- if (ret == -EAGAIN)
- {
- if (jiffies < recv_timeout) recv_timeout = jiffies; // Handle overflow
- if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT)
- error_dev_va("ERROR: Receive timeout reached (%d of %d secs).", (int)((jiffies - recv_timeout) / HZ), (int)SOCKET_KEEPALIVE_TIMEOUT);
- continue;
+ ret = dnbd3_recv_reply(dev->sock, &reply_hdr);
+ if (ret == 0) {
+ /* have not received any data, but remote peer is shutdown properly */
+ dnbd3_dev_dbg_cur(dev, "remote peer has performed an orderly shutdown\n");
+ goto out_unlock;
+ } else if (ret < 0) {
+ if (ret == -EAGAIN) {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev, "receive timeout reached\n");
+ } else {
+ /* for all errors other than -EAGAIN, print errno */
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev, "connection to server lost (receive, errno=%d)\n", ret);
+ }
+ goto out_unlock;
}
- if (ret <= 0)
- error_dev("ERROR: Connection to server lost (receive)");
- if (ret != sizeof(dnbd3_reply))
- error_dev("ERROR: Recv msg header.");
- fixup_reply(dnbd3_reply);
- // check error
- if (dnbd3_reply.magic != dnbd3_packet_magic)
- error_dev("ERROR: Wrong packet magic (Receive).");
- if (dnbd3_reply.cmd == 0)
- error_dev("ERROR: Command was 0 (Receive).");
+ /* check if arrived data is valid */
+ if (ret != sizeof(reply_hdr)) {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev, "recv partial msg header (%d/%d bytes)\n",
+ ret, (int)sizeof(reply_hdr));
+ goto out_unlock;
+ }
- // Update timeout
- recv_timeout = jiffies;
+ // check error
+ if (reply_hdr.magic != dnbd3_packet_magic) {
+ dnbd3_dev_err_cur(dev, "wrong packet magic (receive)\n");
+ goto out_unlock;
+ }
// what to do?
- switch (dnbd3_reply.cmd)
- {
+ switch (reply_hdr.cmd) {
case CMD_GET_BLOCK:
// search for replied request in queue
blk_request = NULL;
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_for_each_entry_safe(received_request, tmp_request, &dev->request_queue_receive, queuelist)
- {
- if ((uint64_t)(uintptr_t)received_request == dnbd3_reply.handle) // Double cast to prevent warning on 32bit
- {
- blk_request = received_request;
+ spin_lock_irqsave(&dev->recv_queue_lock, irqflags);
+ list_for_each_entry(rq_iter, &dev->recv_queue, queuelist) {
+ cmd = blk_mq_rq_to_pdu(rq_iter);
+ if (cmd->handle == reply_hdr.handle) {
+ blk_request = rq_iter;
+ list_del_init(&blk_request->queuelist);
break;
}
}
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- if (blk_request == NULL )
- error_dev_va("ERROR: Received block data for unrequested handle (%llu: %llu).\n",
- (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size);
+ spin_unlock_irqrestore(&dev->recv_queue_lock, irqflags);
+ if (blk_request == NULL) {
+ dnbd3_dev_err_cur(dev, "received block data for unrequested handle (%llx: len=%llu)\n",
+ reply_hdr.handle,
+ (u64)reply_hdr.size);
+ goto out_unlock;
+ }
// receive data and answer to block layer
#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 14, 0)
- rq_for_each_segment(bvec_inst, blk_request, iter)
+ rq_for_each_segment(bvec_inst, blk_request, iter) {
#else
- rq_for_each_segment(bvec, blk_request, iter)
+ rq_for_each_segment(bvec, blk_request, iter) {
#endif
- {
- siginitsetinv(&blocked, sigmask(SIGKILL));
- sigprocmask(SIG_SETMASK, &blocked, &oldset);
-
kaddr = kmap(bvec->bv_page) + bvec->bv_offset;
iov.iov_base = kaddr;
iov.iov_len = bvec->bv_len;
- if (kernel_recvmsg(dev->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len)
- {
- kunmap(bvec->bv_page);
- sigprocmask(SIG_SETMASK, &oldset, NULL );
- error_dev("ERROR: Receiving from net to block layer.");
- }
+ ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags);
kunmap(bvec->bv_page);
-
- sigprocmask(SIG_SETMASK, &oldset, NULL );
+ if (ret != bvec->bv_len) {
+ if (ret == 0) {
+ /* have not received any data, but remote peer is shutdown properly */
+ dnbd3_dev_dbg_cur(
+ dev, "remote peer has performed an orderly shutdown\n");
+ } else if (ret < 0) {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev,
+ "disconnect: receiving from net to block layer\n");
+ } else {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev,
+ "receiving from net to block layer (%d bytes)\n", ret);
+ }
+ // Requeue request
+ spin_lock_irqsave(&dev->send_queue_lock, irqflags);
+ list_add(&blk_request->queuelist, &dev->send_queue);
+ spin_unlock_irqrestore(&dev->send_queue_lock, irqflags);
+ goto out_unlock;
+ }
}
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- list_del_init(&blk_request->queuelist);
- __blk_end_request_all(blk_request, 0);
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- continue;
+ blk_mq_end_request(blk_request, BLK_STS_OK);
+ break;
case CMD_GET_SERVERS:
- if (!dev->use_server_provided_alts)
- {
- remaining = dnbd3_reply.size;
- goto consume_payload;
- }
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = 0;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t));
-
- if (count != 0)
- {
- iov.iov_base = dev->new_servers;
- iov.iov_len = count * sizeof(dnbd3_server_entry_t);
- if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags)
- != (count * sizeof(dnbd3_server_entry_t)))
- error_dev("ERROR: Recv CMD_GET_SERVERS payload.");
- spin_lock_irqsave(&dev->blk_lock, irqflags);
- dev->new_servers_num = count;
- spin_unlock_irqrestore(&dev->blk_lock, irqflags);
- }
- // If there were more servers than accepted, remove the remaining data from the socket buffer
- remaining = dnbd3_reply.size - (count * sizeof(dnbd3_server_entry_t));
- consume_payload: while (remaining > 0)
- {
- count = MIN(sizeof(dnbd3_reply), remaining); // Abuse the reply struct as the receive buffer
- iov.iov_base = &dnbd3_reply;
- iov.iov_len = count;
- ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags);
- if (ret <= 0)
- error_dev("ERROR: Recv additional payload from CMD_GET_SERVERS.");
- remaining -= ret;
+ remaining = reply_hdr.size;
+ if (dev->use_server_provided_alts) {
+ dnbd3_server_entry_t new_server;
+
+ while (remaining >= sizeof(dnbd3_server_entry_t)) {
+ if (dnbd3_recv_bytes(dev->sock, &new_server, sizeof(new_server))
+ != sizeof(new_server)) {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dnbd3_dev_err_cur(dev, "recv CMD_GET_SERVERS payload\n");
+ goto out_unlock;
+ }
+ // TODO: Log
+ if (new_server.failures == 0) { // ADD
+ dnbd3_add_server(dev, &new_server.host);
+ } else { // REM
+ dnbd3_rem_server(dev, &new_server.host);
+ }
+ remaining -= sizeof(new_server);
+ }
}
- continue;
+ if (!dnbd3_drain_socket(dev, dev->sock, remaining))
+ goto out_unlock;
+ break;
case CMD_LATEST_RID:
- if (dnbd3_reply.size != 2)
- {
- printk("ERROR: CMD_LATEST_RID.size != 2.\n");
+ if (reply_hdr.size < 2) {
+ dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size < 2\n");
continue;
}
- iov.iov_base = &rid;
- iov.iov_len = sizeof(rid);
- if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0)
- {
- printk("ERROR: Could not receive CMD_LATEST_RID payload.\n");
- }
- else
- {
+ if (dnbd3_recv_bytes(dev->sock, &rid, 2) != 2) {
+ if (!dnbd3_flag_taken(dev->connection_lock))
+ dev_err(dnbd3_device_to_dev(dev), "could not receive CMD_LATEST_RID payload\n");
+ } else {
rid = net_order_16(rid);
- printk("Latest rid of %s is %d (currently using %d)\n", dev->imgname, (int)rid, (int)dev->rid);
+ dnbd3_dev_info_cur(dev, "latest rid of %s is %d (currently using %d)\n",
+ dev->imgname, (int)rid, (int)dev->rid);
dev->update_available = (rid > dev->rid ? 1 : 0);
}
+ if (reply_hdr.size > 2)
+ dnbd3_drain_socket(dev, dev->sock, reply_hdr.size - 2);
continue;
case CMD_KEEPALIVE:
- if (dnbd3_reply.size != 0)
- printk("ERROR: keep alive packet with payload.\n");
+ if (reply_hdr.size != 0) {
+ dev_dbg(dnbd3_device_to_dev(dev), "keep alive packet with payload\n");
+ dnbd3_drain_socket(dev, dev->sock, reply_hdr.size);
+ }
continue;
default:
- printk("ERROR: Unknown command (Receive)\n");
- continue;
+ dev_err(dnbd3_device_to_dev(dev), "unknown command: %d (receive), aborting connection\n", (int)reply_hdr.cmd);
+ goto out_unlock;
+ }
+ }
+out_unlock:
+ // This will check if we actually still need a new connection
+ dnbd3_start_discover(dev, true);
+ mutex_unlock(&dev->recv_mutex);
+}
+/**
+ * Set send or receive timeout of given socket
+ */
+static void set_socket_timeout(struct socket *sock, bool set_send, int timeout_ms)
+{
+#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 1, 0)
+ int opt = set_send ? SO_SNDTIMEO_NEW : SO_RCVTIMEO_NEW;
+ struct __kernel_sock_timeval timeout;
+#else
+ int opt = set_send ? SO_SNDTIMEO : SO_RCVTIMEO;
+ struct timeval timeout;
+#endif
+#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 9, 0)
+ sockptr_t timeout_ptr = KERNEL_SOCKPTR(&timeout);
+#else
+ char *timeout_ptr = (char *)&timeout;
+#endif
+
+ timeout.tv_sec = timeout_ms / 1000;
+ timeout.tv_usec = (timeout_ms % 1000) * 1000;
+ sock_setsockopt(sock, SOL_SOCKET, opt, timeout_ptr, sizeof(timeout));
+}
+
+static int dnbd3_connect(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket **sock_out)
+{
+ ktime_t start;
+ int ret, connect_time_ms;
+ struct socket *sock;
+ int retries = 4;
+ const int addrlen = addr->ss_family == AF_INET ? sizeof(struct sockaddr_in)
+ : sizeof(struct sockaddr_in6);
+
+#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 2, 0)
+ ret = sock_create_kern(&init_net, addr->ss_family, SOCK_STREAM,
+ IPPROTO_TCP, &sock);
+#else
+ ret = sock_create_kern(addr->ss_family, SOCK_STREAM,
+ IPPROTO_TCP, &sock);
+#endif
+ if (ret < 0) {
+ dev_err(dnbd3_device_to_dev(dev), "couldn't create socket: %d\n", ret);
+ return ret;
+ }
+
+ /* Only one retry, TCP no delay */
+#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 8, 0)
+ tcp_sock_set_syncnt(sock->sk, 1);
+ tcp_sock_set_nodelay(sock->sk);
+ /* because of our aggressive timeouts, this is pointless */
+ sock_no_linger(sock->sk);
+#else
+ /* add legacy version of this, but ignore others as they're not that important */
+ ret = 1;
+ kernel_setsockopt(sock, IPPROTO_TCP, TCP_SYNCNT,
+ (char *)&ret, sizeof(ret));
+#endif
+ /* allow this socket to use reserved mem (vm.mem_free_kbytes) */
+ sk_set_memalloc(sock->sk);
+ sock->sk->sk_allocation = GFP_NOIO;
+
+ if (dev->panic && dev->panic_count > 1) {
+ /* in panic mode for some time, start increasing timeouts */
+ connect_time_ms = dev->panic_count * 1000;
+ } else {
+ /* otherwise, use 2*RTT of current server */
+ connect_time_ms = dev->cur_server.rtt * 2 / 1000;
+ }
+ /* but obey a minimal configurable value, and maximum sanity check */
+ if (connect_time_ms < SOCKET_TIMEOUT_SEND * 1000)
+ connect_time_ms = SOCKET_TIMEOUT_SEND * 1000;
+ else if (connect_time_ms > 60000)
+ connect_time_ms = 60000;
+ set_socket_timeout(sock, false, connect_time_ms); // recv
+ set_socket_timeout(sock, true, connect_time_ms); // send
+ start = ktime_get_real();
+ while (--retries > 0) {
+ ret = kernel_connect(sock, (struct sockaddr *)addr, addrlen, 0);
+ connect_time_ms = (int)ktime_ms_delta(ktime_get_real(), start);
+ if (connect_time_ms > 2 * SOCKET_TIMEOUT_SEND * 1000) {
+ /* Either I'm losing my mind or there was a specific build of kernel
+ * 5.x where SO_RCVTIMEO didn't affect the connect call above, so
+ * this function would hang for over a minute for unreachable hosts.
+ * Leave in this debug check for twice the configured timeout
+ */
+ dnbd3_dev_dbg_host(dev, addr, "connect: call took %dms\n",
+ connect_time_ms);
}
+ if (ret != 0) {
+ if (ret == -EINTR)
+ dnbd3_dev_dbg_host(dev, addr, "connect: interrupted system call (blocked %dms)\n",
+ connect_time_ms);
+ else
+ dnbd3_dev_dbg_host(dev, addr, "connect: failed (%d, blocked %dms)\n",
+ ret, connect_time_ms);
+ goto error;
+ }
+ *sock_out = sock;
+ return 0;
}
+error:
+ sock_release(sock);
+ return ret < 0 ? ret : -EIO;
+}
- printk("dnbd3_net_receive terminated normally.\n");
- dev->thread_receive = NULL;
- return 0;
+#define dnbd3_err_dbg_host(...) do { \
+ if (dev->panic || dev->sock == NULL) \
+ dnbd3_dev_err_host(__VA_ARGS__); \
+ else \
+ dnbd3_dev_dbg_host(__VA_ARGS__); \
+} while (0)
+
+/**
+ * Execute protocol handshake on a newly connected socket.
+ * If this is the initial connection to any server, i.e. we're being called
+ * through the initial ioctl() to open a device, we store the rid, filesize
+ * etc. in the dev struct; otherwise this is a potential switch to another
+ * server, so we validate the filesize, rid and name against what we expect.
+ * The server's protocol version is returned in 'remote_version'.
+ */
+static bool dnbd3_execute_handshake(dnbd3_device_t *dev, struct socket *sock,
+ struct sockaddr_storage *addr, uint16_t *remote_version, bool copy_data)
+{
+ unsigned long irqflags;
+ const char *name;
+ uint64_t filesize;
+ int mlen;
+ uint16_t rid;
+ struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL };
+ struct kvec iov[2];
+ serialized_buffer_t *payload;
+ dnbd3_reply_t reply_hdr;
+ dnbd3_request_t request_hdr = { .magic = dnbd3_packet_magic };
+
+ payload = kmalloc(sizeof(*payload), GFP_KERNEL);
+ if (payload == NULL)
+ goto error;
+
+ if (copy_data && device_active(dev))
+ dev_warn(dnbd3_device_to_dev(dev), "Called handshake function with copy_data enabled when reported_size is not zero\n");
+
+ // Request filesize
+ request_hdr.cmd = CMD_SELECT_IMAGE;
+ iov[0].iov_base = &request_hdr;
+ iov[0].iov_len = sizeof(request_hdr);
+ serializer_reset_write(payload);
+ serializer_put_uint16(payload, PROTOCOL_VERSION); // DNBD3 protocol version
+ serializer_put_string(payload, dev->imgname); // image name
+ serializer_put_uint16(payload, dev->rid); // revision id
+ serializer_put_uint8(payload, 0); // are we a server? (no!)
+ iov[1].iov_base = payload;
+ request_hdr.size = iov[1].iov_len = serializer_get_written_length(payload);
+ fixup_request(request_hdr);
+ mlen = iov[0].iov_len + iov[1].iov_len;
+ if (kernel_sendmsg(sock, &msg, iov, 2, mlen) != mlen) {
+ dnbd3_err_dbg_host(dev, addr, "requesting image size failed\n");
+ goto error;
+ }
+
+ // receive net reply
+ if (dnbd3_recv_reply(sock, &reply_hdr) != sizeof(reply_hdr)) {
+ dnbd3_err_dbg_host(dev, addr, "receiving image size packet (header) failed\n");
+ goto error;
+ }
+ if (reply_hdr.magic != dnbd3_packet_magic
+ || reply_hdr.cmd != CMD_SELECT_IMAGE || reply_hdr.size < 4
+ || reply_hdr.size > sizeof(*payload)) {
+ dnbd3_err_dbg_host(dev, addr,
+ "corrupt CMD_SELECT_IMAGE reply\n");
+ goto error;
+ }
+
+ // receive data
+ iov[0].iov_base = payload;
+ iov[0].iov_len = reply_hdr.size;
+ if (kernel_recvmsg(sock, &msg, iov, 1, reply_hdr.size, msg.msg_flags)
+ != reply_hdr.size) {
+ dnbd3_err_dbg_host(dev, addr,
+ "receiving payload of CMD_SELECT_IMAGE reply failed\n");
+ goto error;
+ }
+ serializer_reset_read(payload, reply_hdr.size);
+
+ *remote_version = serializer_get_uint16(payload);
+ name = serializer_get_string(payload);
+ rid = serializer_get_uint16(payload);
+ filesize = serializer_get_uint64(payload);
+
+ if (*remote_version < MIN_SUPPORTED_SERVER) {
+ dnbd3_err_dbg_host(dev, addr,
+ "server version too old (client: %d, server: %d, min supported: %d)\n",
+ (int)PROTOCOL_VERSION, (int)*remote_version,
+ (int)MIN_SUPPORTED_SERVER);
+ goto error;
+ }
+ if (name == NULL) {
+ dnbd3_err_dbg_host(dev, addr, "server did not supply an image name\n");
+ goto error;
+ }
+ if (rid == 0) {
+ dnbd3_err_dbg_host(dev, addr, "server did not supply a revision id\n");
+ goto error;
+ }
+
+ if (copy_data) {
+ if (filesize < DNBD3_BLOCK_SIZE) {
+			dnbd3_err_dbg_host(dev, addr, "size reported by server is < 4096\n");
+ goto error;
+ }
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+		if (strlen(dev->imgname) < strlen(name)) {
+			/* blk_lock is held with IRQs off, so the allocation must not
+			 * sleep; keep the old pointer so it isn't leaked on failure */
+			char *new_name = krealloc(dev->imgname, strlen(name) + 1, GFP_ATOMIC);
+
+			if (new_name == NULL) {
+				spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+				dnbd3_err_dbg_host(dev, addr, "reallocating buffer for new image name failed\n");
+				goto error;
+			}
+			dev->imgname = new_name;
+		}
+ strcpy(dev->imgname, name);
+ dev->rid = rid;
+ // store image information
+ dev->reported_size = filesize;
+ dev->update_available = 0;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */
+ dnbd3_dev_dbg_host(dev, addr, "image size: %llu\n", dev->reported_size);
+ } else {
+ /* switching connection, sanity checks */
+ if (rid != dev->rid) {
+ dnbd3_err_dbg_host(dev, addr,
+ "server supplied wrong rid (client: '%d', server: '%d')\n",
+ (int)dev->rid, (int)rid);
+ goto error;
+ }
+
+ if (strcmp(name, dev->imgname) != 0) {
+ dnbd3_err_dbg_host(dev, addr, "server offers image '%s', requested '%s'\n", name, dev->imgname);
+ goto error;
+ }
+
+ if (filesize != dev->reported_size) {
+ dnbd3_err_dbg_host(dev, addr,
+ "reported image size of %llu does not match expected value %llu\n",
+ (unsigned long long)filesize, (unsigned long long)dev->reported_size);
+ goto error;
+ }
+ }
+ kfree(payload);
+ return true;
+
+error:
+ kfree(payload);
+ return false;
+}
+
+static bool dnbd3_send_request(struct socket *sock, u16 cmd, u64 handle, u64 offset, u32 size)
+{
+ struct msghdr msg = { .msg_flags = MSG_NOSIGNAL };
+ dnbd3_request_t request_hdr = {
+ .magic = dnbd3_packet_magic,
+ .cmd = cmd,
+ .size = size,
+ .offset = offset,
+ .handle = handle,
+ };
+ struct kvec iov = { .iov_base = &request_hdr, .iov_len = sizeof(request_hdr) };
+
+ fixup_request(request_hdr);
+ return kernel_sendmsg(sock, &msg, &iov, 1, sizeof(request_hdr)) == sizeof(request_hdr);
+}
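+
+/*
+ * Usage sketch: dnbd3_send_request(sock, CMD_GET_BLOCK, handle, offset, len)
+ * pushes a single fixed-size request header onto the wire; fixup_request()
+ * adjusts the header for the wire format before it is sent. Callers that
+ * share dev->sock must hold dev->send_mutex, as dnbd3_send_empty_request()
+ * below does.
+ */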
+
+/**
+ * Send a request with given cmd type and empty payload.
+ */
+static bool dnbd3_send_empty_request(dnbd3_device_t *dev, u16 cmd)
+{
+ int ret;
+
+ mutex_lock(&dev->send_mutex);
+ ret = dev->sock
+ && dnbd3_send_request(dev->sock, cmd, 0, 0, 0);
+ mutex_unlock(&dev->send_mutex);
+ return ret;
+}
+
+static int dnbd3_recv_bytes(struct socket *sock, void *buffer, size_t count)
+{
+ struct msghdr msg = { .msg_flags = MSG_NOSIGNAL | MSG_WAITALL };
+ struct kvec iov = { .iov_base = buffer, .iov_len = count };
+
+ return kernel_recvmsg(sock, &msg, &iov, 1, count, msg.msg_flags);
+}
+
+static int dnbd3_recv_reply(struct socket *sock, dnbd3_reply_t *reply_hdr)
+{
+ int ret = dnbd3_recv_bytes(sock, reply_hdr, sizeof(*reply_hdr));
+
+ fixup_reply(*reply_hdr);
+ return ret;
+}
+
+static bool dnbd3_drain_socket(dnbd3_device_t *dev, struct socket *sock, int bytes)
+{
+ int ret;
+ struct kvec iov;
+ struct msghdr msg = { .msg_flags = MSG_NOSIGNAL };
+
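+	/*
+	 * Read and discard 'bytes' bytes of payload into the shared
+	 * __garbage_mem scratch buffer, for use when a reply's data is not
+	 * actually needed, as with the RTT test block below.
+	 */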
+ while (bytes > 0) {
+ iov.iov_base = __garbage_mem;
+ iov.iov_len = sizeof(__garbage_mem);
+ ret = kernel_recvmsg(sock, &msg, &iov, 1, MIN(bytes, iov.iov_len), msg.msg_flags);
+ if (ret <= 0) {
+ dnbd3_dev_err_cur(dev, "draining payload failed (ret=%d)\n", ret);
+ return false;
+ }
+ bytes -= ret;
+ }
+ return true;
+}
+
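+/*
+ * Fetch RTT_BLOCK_SIZE bytes from offset 0 and throw the data away; the
+ * discovery code presumably uses the time this takes as an RTT measurement
+ * when comparing servers. Returns false on any send/receive failure or
+ * protocol violation.
+ */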
+static bool dnbd3_request_test_block(dnbd3_device_t *dev, struct sockaddr_storage *addr, struct socket *sock)
+{
+ dnbd3_reply_t reply_hdr;
+
+ // Request block
+ if (!dnbd3_send_request(sock, CMD_GET_BLOCK, 0, 0, RTT_BLOCK_SIZE)) {
+ dnbd3_err_dbg_host(dev, addr, "requesting test block failed\n");
+ return false;
+ }
+
+ // receive net reply
+ if (dnbd3_recv_reply(sock, &reply_hdr) != sizeof(reply_hdr)) {
+ dnbd3_err_dbg_host(dev, addr, "receiving test block header packet failed\n");
+ return false;
+ }
+ if (reply_hdr.magic != dnbd3_packet_magic || reply_hdr.cmd != CMD_GET_BLOCK
+ || reply_hdr.size != RTT_BLOCK_SIZE || reply_hdr.handle != 0) {
+ dnbd3_err_dbg_host(dev, addr,
+ "unexpected reply to block request: cmd=%d, size=%d, handle=%llu (discover)\n",
+ (int)reply_hdr.cmd, (int)reply_hdr.size, reply_hdr.handle);
+ return false;
+ }
- error:
+ // receive data
+ return dnbd3_drain_socket(dev, sock, RTT_BLOCK_SIZE);
+}
+#undef dnbd3_err_dbg_host
+
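+/*
+ * Swap dev->sock for 'sock' (which may be NULL to just tear the connection
+ * down). Shutting the old socket down first makes the receive worker drop
+ * out of its loop, so recv_mutex can then be taken and the old socket
+ * released without racing the worker; send_mutex is held for the whole
+ * exchange so no request is sent on a half-replaced socket.
+ */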
+static void replace_main_socket(dnbd3_device_t *dev, struct socket *sock, struct sockaddr_storage *addr, u16 protocol_version)
+{
+ unsigned long irqflags;
+
+ mutex_lock(&dev->send_mutex);
+	// First, shut down the connection so the receive worker leaves its main loop
if (dev->sock)
kernel_sock_shutdown(dev->sock, SHUT_RDWR);
- if (!dev->disconnecting)
- {
- dev->panic = 1;
- dev->discover = 1;
- wake_up(&dev->process_queue_discover);
+ mutex_lock(&dev->recv_mutex);
+	// Receive worker is done; release the old socket and install the new one
+ if (dev->sock)
+ sock_release(dev->sock);
+ dev->sock = sock;
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ if (addr == NULL) {
+ memset(&dev->cur_server, 0, sizeof(dev->cur_server));
+ } else {
+ dev->cur_server.host = *addr;
+ dev->cur_server.rtt = 0;
+ dev->cur_server.protocol_version = protocol_version;
}
- dev->thread_receive = NULL;
- return -1;
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ mutex_unlock(&dev->recv_mutex);
+ mutex_unlock(&dev->send_mutex);
}
+static void dnbd3_release_resources(dnbd3_device_t *dev)
+{
+ if (dev->send_wq)
+ destroy_workqueue(dev->send_wq);
+ dev->send_wq = NULL;
+ if (dev->recv_wq)
+ destroy_workqueue(dev->recv_wq);
+ dev->recv_wq = NULL;
+ mutex_destroy(&dev->send_mutex);
+ mutex_destroy(&dev->recv_mutex);
+}
+
+/**
+ * Establish new connection on a dnbd3 device.
+ * Return 0 on success, errno otherwise
+ */
+int dnbd3_new_connection(dnbd3_device_t *dev, struct sockaddr_storage *addr, bool init)
+{
+ unsigned long irqflags;
+ struct socket *sock = NULL;
+ uint16_t proto_version;
+ int ret;
+
+ ASSERT(dnbd3_flag_taken(dev->connection_lock));
+ if (init && device_active(dev)) {
+ dnbd3_dev_err_cur(dev, "device already configured/connected\n");
+ return -EBUSY;
+ }
+ if (!init && !device_active(dev)) {
+ dev_warn(dnbd3_device_to_dev(dev), "connection switch called on unconfigured device\n");
+ return -ENOTCONN;
+ }
+
+ dnbd3_dev_dbg_host(dev, addr, "connecting...\n");
+ ret = dnbd3_connect(dev, addr, &sock);
+ if (ret != 0 || sock == NULL)
+ goto error;
+
+ /* execute the "select image" handshake */
+ // if init is true, reported_size will be set
+ if (!dnbd3_execute_handshake(dev, sock, addr, &proto_version, init)) {
+ ret = -EINVAL;
+ goto error;
+ }
+
+ if (init) {
+ // We're setting up the device for use - allocate resources
+ // Do not goto error before this
+ ASSERT(!dev->send_wq);
+ ASSERT(!dev->recv_wq);
+ mutex_init(&dev->send_mutex);
+ mutex_init(&dev->recv_mutex);
+		// a dedicated queue for sending that allows only one active task
+ dev->send_wq = alloc_workqueue("dnbd%d-send",
+ WQ_UNBOUND | WQ_FREEZABLE | WQ_MEM_RECLAIM | WQ_HIGHPRI,
+ 1, dev->index);
+ dev->recv_wq = alloc_workqueue("dnbd%d-recv",
+ WQ_UNBOUND | WQ_FREEZABLE | WQ_MEM_RECLAIM | WQ_HIGHPRI | WQ_CPU_INTENSIVE,
+ 1, dev->index);
+ if (!dev->send_wq || !dev->recv_wq) {
+ ret = -ENOMEM;
+ goto error_dealloc;
+ }
+ }
+
+ set_socket_timeout(sock, false, SOCKET_TIMEOUT_RECV * 1000); // recv
+ dnbd3_set_primary_connection(dev, sock, addr, proto_version);
+ sock = NULL; // In case we ever goto error* after this point
+
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ if (init) {
+ dev->discover_count = 0;
+ dev->discover_interval = TIMER_INTERVAL_PROBE_STARTUP;
+		// discovery and keepalive are not critical, so use the power-efficient queue
+ queue_delayed_work(system_power_efficient_wq, &dev->discover_work,
+ dev->discover_interval * HZ);
+ queue_delayed_work(system_power_efficient_wq, &dev->keepalive_work,
+ KEEPALIVE_INTERVAL * HZ);
+		// but the receiver is performance critical AND runs indefinitely, so use
+		// the CPU-intensive queue, as jobs submitted there do not count towards
+		// the concurrency limit of the per-cpu worker threads. It still feels a
+		// little dirty not to manage our own thread, but nbd does it too.
+ }
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+ return 0;
+
+error_dealloc:
+ if (init) {
+ // If anything fails during initialization, free resources again
+ dnbd3_release_resources(dev);
+ }
+error:
+ if (init)
+ dev->reported_size = 0;
+ if (sock)
+ sock_release(sock);
+ return ret < 0 ? ret : -EIO;
+}
+
+void dnbd3_net_work_init(dnbd3_device_t *dev)
+{
+ INIT_WORK(&dev->send_work, dnbd3_send_workfn);
+ INIT_WORK(&dev->recv_work, dnbd3_recv_workfn);
+ INIT_DELAYED_WORK(&dev->discover_work, dnbd3_discover_workfn);
+ INIT_DELAYED_WORK(&dev->keepalive_work, dnbd3_keepalive_workfn);
+}
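+
+/*
+ * Rough split of the work items initialised above: recv_work runs on the
+ * per-device recv_wq allocated in dnbd3_new_connection() (and send_work
+ * presumably on send_wq), while discover_work and keepalive_work are
+ * periodic jobs queued on system_power_efficient_wq.
+ */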
+
+static int dnbd3_set_primary_connection(dnbd3_device_t *dev, struct socket *sock, struct sockaddr_storage *addr, u16 protocol_version)
+{
+ unsigned long irqflags;
+
+ ASSERT(dnbd3_flag_taken(dev->connection_lock));
+ if (addr->ss_family == 0 || dev->imgname == NULL || sock == NULL) {
+ dnbd3_dev_err_cur(dev, "connect: host, image name or sock not set\n");
+ return -EINVAL;
+ }
+
+ replace_main_socket(dev, sock, addr, protocol_version);
+ spin_lock_irqsave(&dev->blk_lock, irqflags);
+ dev->panic = false;
+ dev->panic_count = 0;
+ dev->discover_interval = TIMER_INTERVAL_PROBE_SWITCH;
+ queue_work(dev->recv_wq, &dev->recv_work);
+ spin_unlock_irqrestore(&dev->blk_lock, irqflags);
+
+ if (dev->use_server_provided_alts)
+ dnbd3_send_empty_request(dev, CMD_GET_SERVERS);
+
+ dnbd3_dev_info_cur(dev, "connection switched\n");
+ dnbd3_blk_requeue_all_requests(dev);
+ return 0;
+}
+
+/**
+ * Disconnect the device, shutting it down.
+ */
+int dnbd3_net_disconnect(dnbd3_device_t *dev)
+{
+ ASSERT(dnbd3_flag_taken(dev->connection_lock));
+ if (!device_active(dev))
+ return -ENOTCONN;
+ dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n");
+
+ dev->reported_size = 0;
+ /* quickly fail all requests */
+ dnbd3_blk_fail_all_requests(dev);
+ replace_main_socket(dev, NULL, NULL, 0);
+
+ cancel_delayed_work_sync(&dev->keepalive_work);
+ cancel_delayed_work_sync(&dev->discover_work);
+ cancel_work_sync(&dev->send_work);
+ cancel_work_sync(&dev->recv_work);
+
+ dnbd3_blk_fail_all_requests(dev);
+ dnbd3_release_resources(dev);
+ dev_dbg(dnbd3_device_to_dev(dev), "all workers shut down\n");
+ return 0;
+}