summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Rettberg2020-11-19 13:48:14 +0100
committerSimon Rettberg2020-11-19 13:48:14 +0100
commita1caec1f2bebe09f685716f13d7b55f84d8a8145 (patch)
tree5eef7b3f63fa1a240bf76e1fa46859946ac0adb8
parent[KERNEL] add support for Linux kernel 4.19 on Ubuntu 18.04 (GCC 7.5) (diff)
downloaddnbd3-a1caec1f2bebe09f685716f13d7b55f84d8a8145.tar.gz
dnbd3-a1caec1f2bebe09f685716f13d7b55f84d8a8145.tar.xz
dnbd3-a1caec1f2bebe09f685716f13d7b55f84d8a8145.zip
[KERNEL] Fix several connect/disconnect race conditions
Previously, disconnect was protected against concurrent calls, but connect wasn't. It was easy to crash the kernel when calling connect and disconnect IOCTLs in a tight loop concurrently. A global lock was introduced to make sure only one caller can change the connection state at a time. dev->connection_lock needs to be acquired when calling dnbd3_net_connect or dnbd3_net_disconnect. This atomic_t based locking mechanism should be turned into a mutex in a future step, relying on mutex_trylock for cases where we don't have the cmpxchg-schedule() loop. Along the way it was noticed that the send/receive timeouts don't apply to kernel_connect, which might have been the case in older 3.x kernel versions. A crude workaround using nonblocking connect has been introduced to emulate this, but a clean solution for this would be welcome. Also, devices are now properly closed on module unload.
-rw-r--r--src/kernel/blk.c57
-rw-r--r--src/kernel/dnbd3_main.h3
-rw-r--r--src/kernel/net.c321
3 files changed, 193 insertions, 188 deletions
diff --git a/src/kernel/blk.c b/src/kernel/blk.c
index 69d02d5..43251d3 100644
--- a/src/kernel/blk.c
+++ b/src/kernel/blk.c
@@ -33,6 +33,25 @@
#define dnbd3_req_special(req) \
blk_rq_is_private(req)
+static int dnbd3_close_device(dnbd3_device_t *dev)
+{
+ int result;
+ dnbd3_blk_fail_all_requests(dev);
+ dev->panic = 0;
+ dev->discover = 0;
+ result = dnbd3_net_disconnect(dev);
+ dnbd3_blk_fail_all_requests(dev);
+ blk_mq_freeze_queue(dev->queue);
+ set_capacity(dev->disk, 0);
+ blk_mq_unfreeze_queue(dev->queue);
+ if (dev->imgname)
+ {
+ kfree(dev->imgname);
+ dev->imgname = NULL;
+ }
+ return result;
+}
+
static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int cmd, unsigned long arg)
{
int result = -100;
@@ -45,8 +64,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
dnbd3_server_t *alt_server;
unsigned long irqflags;
int i = 0;
-
- while (dev->disconnecting) { /* do nothing */ }
+ u8 locked = 0;
if (arg != 0)
{
@@ -83,6 +101,12 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
switch (cmd)
{
case IOCTL_OPEN:
+ if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
+ {
+ result = -EBUSY;
+ break;
+ }
+ locked = 1;
if (dev->imgname != NULL)
{
result = -EBUSY;
@@ -165,20 +189,22 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
break;
case IOCTL_CLOSE:
- dnbd3_blk_fail_all_requests(dev);
- result = dnbd3_net_disconnect(dev);
- dnbd3_blk_fail_all_requests(dev);
- blk_mq_freeze_queue(dev->queue);
- set_capacity(dev->disk, 0);
- blk_mq_unfreeze_queue(dev->queue);
- if (dev->imgname)
+ if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
{
- kfree(dev->imgname);
- dev->imgname = NULL;
+ result = -EBUSY;
+ break;
}
+ locked = 1;
+ result = dnbd3_close_device(dev);
break;
case IOCTL_SWITCH:
+ if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
+ {
+ result = -EBUSY;
+ break;
+ }
+ locked = 1;
if (dev->imgname == NULL)
{
result = -ENOENT;
@@ -278,6 +304,9 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int
break;
}
+ if (locked)
+ atomic_set(&dev->connection_lock, 0);
+
cleanup_return:
if (msg) kfree(msg);
if (imgname) kfree(imgname);
@@ -354,7 +383,7 @@ int dnbd3_blk_add_device(dnbd3_device_t *dev, int minor)
dev->thread_receive = NULL;
dev->thread_discover = NULL;
dev->discover = 0;
- dev->disconnecting = 0;
+ atomic_set(&dev->connection_lock, 0);
dev->panic = 0;
dev->panic_count = 0;
dev->reported_size = 0;
@@ -432,8 +461,10 @@ out:
int dnbd3_blk_del_device(dnbd3_device_t *dev)
{
+ while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
+ schedule();
+ dnbd3_close_device(dev);
dnbd3_sysfs_exit(dev);
- dnbd3_net_disconnect(dev);
del_gendisk(dev->disk);
blk_cleanup_queue(dev->queue);
blk_mq_free_tag_set(&dev->tag_set);
diff --git a/src/kernel/dnbd3_main.h b/src/kernel/dnbd3_main.h
index a3c2828..ec8c8cf 100644
--- a/src/kernel/dnbd3_main.h
+++ b/src/kernel/dnbd3_main.h
@@ -62,7 +62,8 @@ typedef struct
dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers
int new_servers_num; // number of new alt servers that are waiting to be copied to above array
dnbd3_server_entry_t new_servers[NUMBER_SERVERS]; // pending new alt servers
- uint8_t discover, panic, disconnecting, update_available, panic_count;
+ uint8_t discover, panic, update_available, panic_count;
+ atomic_t connection_lock;
uint8_t use_server_provided_alts;
uint16_t rid;
uint32_t heartbeat_count;
diff --git a/src/kernel/net.c b/src/kernel/net.c
index a0444d2..f460458 100644
--- a/src/kernel/net.c
+++ b/src/kernel/net.c
@@ -74,6 +74,8 @@
#define dnbd3_dev_dbg_host_alt(dev, fmt, ...) __dnbd3_dev_dbg_host((dev), (dev)->alt_servers[i].host, fmt, ##__VA_ARGS__)
#define dnbd3_dev_err_host_alt(dev, fmt, ...) __dnbd3_dev_err_host((dev), (dev)->alt_servers[i].host, fmt, ##__VA_ARGS__)
+static struct socket* dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host);
+
static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t * const dev)
{
int i;
@@ -167,25 +169,6 @@ static int dnbd3_net_discover(void *data)
struct request *last_request = (struct request *)123, *cur_request = (struct request *)456;
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
- struct __kernel_sock_timeval timeout;
-#else
- struct timeval timeout;
-#endif
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0)
- sockptr_t timeout_ptr;
-#else
- char *timeout_ptr;
-#endif
-
- timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DISCOVERY;
- timeout.tv_usec = 0;
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0)
- timeout_ptr = KERNEL_SOCKPTR(&timeout);
-#else
- timeout_ptr = (char *)&timeout;
-#endif
-
memset(&sin4, 0, sizeof(sin4));
memset(&sin6, 0, sizeof(sin6));
@@ -297,35 +280,11 @@ static int dnbd3_net_discover(void *data)
continue;
// Initialize socket and connect
- if (dnbd3_sock_create(dev->alt_servers[i].host.type, SOCK_STREAM, IPPROTO_TCP, &sock) < 0)
+ sock = dnbd3_connect(dev, &dev->alt_servers[i].host);
+ if (sock == NULL)
{
- dnbd3_dev_err_host_alt(dev, "couldn't create socket (discover)\n");
- sock = NULL;
- continue;
- }
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
- sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout));
- sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout));
-#else
- sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout));
- sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout));
-#endif
- sock->sk->sk_allocation = GFP_NOIO;
- if (dev->alt_servers[i].host.type == HOST_IP4)
- {
- sin4.sin_family = AF_INET;
- memcpy(&sin4.sin_addr, dev->alt_servers[i].host.addr, 4);
- sin4.sin_port = dev->alt_servers[i].host.port;
- if (kernel_connect(sock, (struct sockaddr *)&sin4, sizeof(sin4), 0) < 0)
- goto error;
- }
- else
- {
- sin6.sin6_family = AF_INET6;
- memcpy(&sin6.sin6_addr, dev->alt_servers[i].host.addr, 16);
- sin6.sin6_port = dev->alt_servers[i].host.port;
- if (kernel_connect(sock, (struct sockaddr *)&sin6, sizeof(sin6), 0) < 0)
- goto error;
+ dnbd3_dev_dbg_host_alt(dev, "dnbd3_net_discover: Couldn't connect\n");
+ goto error;
}
// Request filesize
@@ -413,17 +372,27 @@ static int dnbd3_net_discover(void *data)
// panic mode, take first responding server
if (dev->panic)
{
- dev->panic = 0;
dnbd3_dev_dbg_host_alt(dev, "panic mode, changing server ...\n");
- if (best_sock != NULL )
- sock_release(best_sock);
- dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect();
- kfree(buf);
- dev->thread_discover = NULL;
- dnbd3_net_disconnect(dev);
- memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server));
- dnbd3_net_connect(dev);
- return 0;
+ while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0)
+ {
+ schedule();
+ }
+ if (dev->panic) // Re-check, a connect might have been in progress
+ {
+ dev->panic = 0;
+ if (best_sock != NULL )
+ sock_release(best_sock);
+
+ dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect();
+ kfree(buf);
+ dev->thread_discover = NULL;
+ dnbd3_net_disconnect(dev);
+ memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server));
+ dnbd3_net_connect(dev);
+ atomic_set(&dev->connection_lock, 0);
+ return 0;
+ }
+ atomic_set(&dev->connection_lock, 0);
}
// Request block
@@ -506,10 +475,12 @@ static int dnbd3_net_discover(void *data)
continue;
- error: ;
+error: ;
++dev->alt_servers[i].failures;
- sock_release(sock);
- sock = NULL;
+ if (sock != NULL) {
+ sock_release(sock);
+ sock = NULL;
+ }
dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE;
if (is_same_server(&dev->cur_server, &dev->alt_servers[i]))
{
@@ -557,7 +528,8 @@ static int dnbd3_net_discover(void *data)
}
// take server with lowest rtt
- if (do_change)
+ // if a (dis)connect is already in progress, we do nothing, this is not panic mode
+ if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0)
{
dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", best_server,
(unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt);
@@ -568,7 +540,7 @@ static int dnbd3_net_discover(void *data)
memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server));
dev->cur_rtt = best_rtt;
dnbd3_net_connect(dev);
-
+ atomic_set(&dev->connection_lock, 0);
return 0;
}
@@ -658,7 +630,7 @@ static int dnbd3_net_send(void *data)
break;
default:
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dev_err(dnbd3_device_to_dev(dev), "unknown command (send %u %u)\n", (int)blk_request->cmd_flags, (int)dnbd3_req_op(blk_request));
list_del_init(&blk_request->queuelist);
spin_unlock_irqrestore(&dev->blk_lock, irqflags);
@@ -673,7 +645,7 @@ static int dnbd3_net_send(void *data)
iov.iov_len = sizeof(dnbd3_request);
if (kernel_sendmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request))
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "connection to server lost (send)\n");
ret = -ESHUTDOWN;
goto cleanup;
@@ -687,14 +659,14 @@ static int dnbd3_net_send(void *data)
cleanup:
if (dev->sock)
kernel_sock_shutdown(dev->sock, SHUT_RDWR);
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
{
dev->panic = 1;
dev->discover = 1;
wake_up(&dev->process_queue_discover);
}
- if (!dev->disconnecting && ret != 0)
+ if (!atomic_read(&dev->connection_lock) && ret != 0)
dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated abnormally\n");
else
dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated normally (cleanup)\n");
@@ -749,7 +721,7 @@ static int dnbd3_net_receive(void *data)
if (jiffies < recv_timeout) recv_timeout = jiffies; // Handle overflow
if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT)
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "receive timeout reached (%d of %d secs)\n", (int)((jiffies - recv_timeout) / HZ), (int)SOCKET_KEEPALIVE_TIMEOUT);
ret = -ETIMEDOUT;
goto cleanup;
@@ -757,7 +729,7 @@ static int dnbd3_net_receive(void *data)
continue;
} else {
/* for all errors other than -EAGAIN, print message and abort thread */
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "connection to server lost (receive)\n");
ret = -ESHUTDOWN;
goto cleanup;
@@ -767,7 +739,7 @@ static int dnbd3_net_receive(void *data)
/* check if arrived data is valid */
if (ret != sizeof(dnbd3_reply))
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "recv msg header\n");
ret = -EINVAL;
goto cleanup;
@@ -777,15 +749,13 @@ static int dnbd3_net_receive(void *data)
// check error
if (dnbd3_reply.magic != dnbd3_packet_magic)
{
- if (!dev->disconnecting)
- dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n");
+ dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n");
ret = -EINVAL;
goto cleanup;
}
if (dnbd3_reply.cmd == 0)
{
- if (!dev->disconnecting)
- dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n");
+ dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n");
ret = -EINVAL;
goto cleanup;
}
@@ -811,9 +781,8 @@ static int dnbd3_net_receive(void *data)
spin_unlock_irqrestore(&dev->blk_lock, irqflags);
if (blk_request == NULL)
{
- if (!dev->disconnecting)
- dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n",
- (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size);
+ dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n",
+ (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size);
ret = -EINVAL;
goto cleanup;
}
@@ -837,7 +806,7 @@ static int dnbd3_net_receive(void *data)
}
else
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "receiving from net to block layer\n");
ret = -EINVAL;
goto cleanup;
@@ -869,7 +838,7 @@ static int dnbd3_net_receive(void *data)
if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags)
!= (count * sizeof(dnbd3_server_entry_t)))
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n");
ret = -EINVAL;
goto cleanup;
@@ -888,7 +857,7 @@ static int dnbd3_net_receive(void *data)
ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags);
if (ret <= 0)
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dnbd3_dev_err_host_cur(dev, "recv additional payload from CMD_GET_SERVERS\n");
ret = -EINVAL;
goto cleanup;
@@ -900,15 +869,14 @@ static int dnbd3_net_receive(void *data)
case CMD_LATEST_RID:
if (dnbd3_reply.size != 2)
{
- if (!dev->disconnecting)
- dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size != 2\n");
+ dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size != 2\n");
continue;
}
iov.iov_base = &rid;
iov.iov_len = sizeof(rid);
if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0)
{
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
dev_err(dnbd3_device_to_dev(dev), "could not receive CMD_LATEST_RID payload\n");
}
else
@@ -922,14 +890,12 @@ static int dnbd3_net_receive(void *data)
case CMD_KEEPALIVE:
if (dnbd3_reply.size != 0)
{
- if (!dev->disconnecting)
- dev_err(dnbd3_device_to_dev(dev), "keep alive packet with payload\n");
+ dev_err(dnbd3_device_to_dev(dev), "keep alive packet with payload\n");
}
continue;
default:
- if (!dev->disconnecting)
- dev_err(dnbd3_device_to_dev(dev), "unknown command (receive)\n");
+ dev_err(dnbd3_device_to_dev(dev), "unknown command (receive)\n");
continue;
}
@@ -942,14 +908,14 @@ static int dnbd3_net_receive(void *data)
cleanup:
if (dev->sock)
kernel_sock_shutdown(dev->sock, SHUT_RDWR);
- if (!dev->disconnecting)
+ if (!atomic_read(&dev->connection_lock))
{
dev->panic = 1;
dev->discover = 1;
wake_up(&dev->process_queue_discover);
}
- if (!dev->disconnecting && ret != 0)
+ if (!atomic_read(&dev->connection_lock) && ret != 0)
dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated abnormally\n");
else
dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated normally (cleanup)\n");
@@ -958,9 +924,10 @@ cleanup:
return ret;
}
-int dnbd3_net_connect(dnbd3_device_t *dev)
+static struct socket* dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host)
{
- struct request *req1 = NULL;
+ int ret;
+ struct socket *sock;
#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
struct __kernel_sock_timeval timeout;
#else
@@ -968,26 +935,81 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0)
sockptr_t timeout_ptr;
+ timeout_ptr = KERNEL_SOCKPTR(&timeout);
#else
char *timeout_ptr;
+ timeout_ptr = (char *)&timeout;
#endif
- if (dev->disconnecting)
+ timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA;
+ timeout.tv_usec = 0;
+
+ if (dnbd3_sock_create(host->type, SOCK_STREAM, IPPROTO_TCP, &sock) < 0)
{
- dnbd3_dev_dbg_host_cur(dev, "connect: wait for disconnect has finished ...\n");
- set_current_state(TASK_INTERRUPTIBLE);
- while (dev->disconnecting)
- schedule();
- dnbd3_dev_dbg_host_cur(dev, "connect: disconnect is done\n");
+ dev_err(dnbd3_device_to_dev(dev), "couldn't create socket\n");
+ return NULL;
}
- timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA;
- timeout.tv_usec = 0;
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0)
- timeout_ptr = KERNEL_SOCKPTR(&timeout);
+#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
+ sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout));
+ sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout));
#else
- timeout_ptr = (char *)&timeout;
+ sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout));
+ sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout));
#endif
+ sock->sk->sk_allocation = GFP_NOIO;
+ if (host->type == HOST_IP4)
+ {
+ struct sockaddr_in sin;
+ memset(&sin, 0, sizeof(sin));
+ sin.sin_family = AF_INET;
+ memcpy(&(sin.sin_addr), host->addr, 4);
+ sin.sin_port = host->port;
+ ret = kernel_connect(sock, (struct sockaddr *)&sin, sizeof(sin), O_NONBLOCK);
+ if (ret != 0 && ret != -EINPROGRESS)
+ {
+ dev_err(dnbd3_device_to_dev(dev), "connection to host failed (v4)\n");
+ goto error;
+ }
+ }
+ else
+ {
+ struct sockaddr_in6 sin;
+ memset(&sin, 0, sizeof(sin));
+ sin.sin6_family = AF_INET6;
+ memcpy(&(sin.sin6_addr), host->addr, 16);
+ sin.sin6_port = host->port;
+ ret = kernel_connect(sock, (struct sockaddr *)&sin, sizeof(sin), O_NONBLOCK);
+ if (ret != 0 && ret != -EINPROGRESS)
+ {
+ dev_err(dnbd3_device_to_dev(dev), "connection to host failed (v6)\n");
+ goto error;
+ }
+ }
+ if (ret != 0) {
+ // XXX How can we do a connect with short timeout? This is dumb
+ ktime_t start = ktime_get_real();
+ while (ktime_ms_delta(ktime_get_real(), start) < SOCKET_TIMEOUT_CLIENT_DATA * 1000) {
+ struct sockaddr_storage addr;
+ ret = kernel_getpeername(sock, (struct sockaddr*)&addr);
+ if (ret >= 0)
+ break;
+ msleep(1);
+ }
+ if (ret < 0) {
+ dev_dbg(dnbd3_device_to_dev(dev), "connect timed out (%d)\n", ret);
+ goto error;
+ }
+ }
+ return sock;
+error:
+ sock_release(sock);
+ return NULL;
+}
+
+int dnbd3_net_connect(dnbd3_device_t *dev)
+{
+ struct request *req1 = NULL;
// do some checks before connecting
req1 = kmalloc(sizeof(*req1), GFP_ATOMIC);
@@ -1029,46 +1051,13 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
int mlen;
init_msghdr(msg);
- if (dnbd3_sock_create(dev->cur_server.host.type, SOCK_STREAM, IPPROTO_TCP, &dev->sock) < 0)
+ dev->sock = dnbd3_connect(dev, &dev->cur_server.host);
+ if (dev->sock == NULL)
{
- dnbd3_dev_err_host_cur(dev, "couldn't create socket (v6)\n");
+ dnbd3_dev_err_host_cur(dev, "dnbd3_net_connect: Failed\n");
goto error;
}
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout));
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout));
-#else
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout));
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout));
-#endif
- dev->sock->sk->sk_allocation = GFP_NOIO;
- if (dev->cur_server.host.type == HOST_IP4)
- {
- struct sockaddr_in sin;
- memset(&sin, 0, sizeof(sin));
- sin.sin_family = AF_INET;
- memcpy(&(sin.sin_addr), dev->cur_server.host.addr, 4);
- sin.sin_port = dev->cur_server.host.port;
- if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0)
- {
- dnbd3_dev_err_host_cur(dev, "connection to host failed (v4)\n");
- goto error;
- }
- }
- else
- {
- struct sockaddr_in6 sin;
- memset(&sin, 0, sizeof(sin));
- sin.sin6_family = AF_INET6;
- memcpy(&(sin.sin6_addr), dev->cur_server.host.addr, 16);
- sin.sin6_port = dev->cur_server.host.port;
- if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0)
- {
- dnbd3_dev_err_host_cur(dev, "connection to host failed (v6)\n");
- }
- }
-
// Request filesize
dnbd3_request.magic = dnbd3_packet_magic;
dnbd3_request.cmd = CMD_SELECT_IMAGE;
@@ -1159,28 +1148,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change ...\n");
dev->sock = dev->better_sock;
dev->better_sock = NULL;
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0)
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout));
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout));
-#else
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout));
- sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout));
-#endif
}
- dev->panic = 0;
- dev->panic_count = 0;
-
- // Enqueue request to request_queue_send for a fresh list of alt servers
- dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS);
- list_add(&req1->queuelist, &dev->request_queue_send);
-
// create required threads
dev->thread_send = kthread_create(dnbd3_net_send, dev, "%s-send", dev->disk->disk_name);
- dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name);
- dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name);
-
- // start them up
if (!IS_ERR(dev->thread_send)) {
get_task_struct(dev->thread_send);
wake_up_process(dev->thread_send);
@@ -1188,9 +1159,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dev_err(dnbd3_device_to_dev(dev), "failed to create send thread\n");
/* reset error to cleanup thread */
dev->thread_send = NULL;
- goto cleanup_thread;
+ goto error;
}
+ dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name);
if (!IS_ERR(dev->thread_receive)) {
get_task_struct(dev->thread_receive);
wake_up_process(dev->thread_receive);
@@ -1198,9 +1170,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dev_err(dnbd3_device_to_dev(dev), "failed to create receive thread\n");
/* reset error to cleanup thread */
dev->thread_receive = NULL;
- goto cleanup_thread;
+ goto error;
}
+ dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name);
if (!IS_ERR(dev->thread_discover)) {
get_task_struct(dev->thread_discover);
wake_up_process(dev->thread_discover);
@@ -1208,12 +1181,21 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
dev_err(dnbd3_device_to_dev(dev), "failed to create discover thread\n");
/* reset error to cleanup thread */
dev->thread_discover = NULL;
- goto cleanup_thread;
+ goto error;
}
+ dev->panic = 0;
+ dev->panic_count = 0;
+
+ // Enqueue request to request_queue_send for a fresh list of alt servers
+ dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS);
+ list_add(&req1->queuelist, &dev->request_queue_send);
+
wake_up(&dev->process_queue_send);
// add heartbeat timer
+ // Do not goto error after creating the timer - we require that the timer exists
+ // if dev->sock != NULL -- see dnbd3_net_disconnect
dev->heartbeat_count = 0;
timer_setup(&dev->hb_timer, dnbd3_net_heartbeat, 0);
dev->hb_timer.expires = jiffies + HZ;
@@ -1221,9 +1203,6 @@ int dnbd3_net_connect(dnbd3_device_t *dev)
return 0;
-cleanup_thread:
- dnbd3_net_disconnect(dev);
-
error:
if (dev->sock)
{
@@ -1244,22 +1223,16 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
bool thread_not_terminated = false;
int ret = 0;
- if (dev->disconnecting) {
- ret = -EBUSY;
- goto out;
- }
-
- dev->disconnecting = 1;
-
dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n");
- // clear heartbeat timer
- del_timer(&dev->hb_timer);
dev->discover = 0;
- if (dev->sock)
+ if (dev->sock) {
kernel_sock_shutdown(dev->sock, SHUT_RDWR);
+ // clear heartbeat timer
+ del_timer(&dev->hb_timer);
+ }
// kill sending and receiving threads
if (dev->thread_send)
@@ -1278,6 +1251,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
thread_not_terminated = true;
}
}
+ dev->thread_send = NULL;
}
if (dev->thread_receive)
@@ -1296,6 +1270,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
thread_not_terminated = true;
}
}
+ dev->thread_receive = NULL;
}
if (dev->thread_discover)
@@ -1314,6 +1289,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
thread_not_terminated = true;
}
}
+ dev->thread_discover = NULL;
}
if (dev->sock)
@@ -1332,8 +1308,5 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev)
ret = 0;
}
- dev->disconnecting = 0;
-
-out:
return ret;
}