From a1caec1f2bebe09f685716f13d7b55f84d8a8145 Mon Sep 17 00:00:00 2001 From: Simon Rettberg Date: Thu, 19 Nov 2020 13:48:14 +0100 Subject: [KERNEL] Fix several connect/disconnect race conditions Previously disconnect was protected against concurrent calls, but connect wasn't. It was easy to crash the kernel when calling connect and disconnect IOCTLs in a tight loop concurrently. A global lock was introduced to make sure only one caller can change the connection state at a time. dev->connection_lock needs to be aquired when calling dnbd3_net_connect or _disconnect. This atomic_t based locking mechanism should be turned into a mutex in a next step, relying on mutex_trylock for cases where we don't have the cmpxchg-schedule() loop. Along the way it was noticed that the send/receive timeouts don't apply to kernel_connect, which might have been the case in older 3.x kernel versions. A crude workaround using nonblocking connect has been introduced to emulate this, but a clean solution for this is welcomed. Also, devices are now properly closed on module unload. --- src/kernel/blk.c | 57 +++++++-- src/kernel/dnbd3_main.h | 3 +- src/kernel/net.c | 321 ++++++++++++++++++++++-------------------------- 3 files changed, 193 insertions(+), 188 deletions(-) diff --git a/src/kernel/blk.c b/src/kernel/blk.c index 69d02d5..43251d3 100644 --- a/src/kernel/blk.c +++ b/src/kernel/blk.c @@ -33,6 +33,25 @@ #define dnbd3_req_special(req) \ blk_rq_is_private(req) +static int dnbd3_close_device(dnbd3_device_t *dev) +{ + int result; + dnbd3_blk_fail_all_requests(dev); + dev->panic = 0; + dev->discover = 0; + result = dnbd3_net_disconnect(dev); + dnbd3_blk_fail_all_requests(dev); + blk_mq_freeze_queue(dev->queue); + set_capacity(dev->disk, 0); + blk_mq_unfreeze_queue(dev->queue); + if (dev->imgname) + { + kfree(dev->imgname); + dev->imgname = NULL; + } + return result; +} + static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int cmd, unsigned long arg) { int result = -100; @@ -45,8 +64,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int dnbd3_server_t *alt_server; unsigned long irqflags; int i = 0; - - while (dev->disconnecting) { /* do nothing */ } + u8 locked = 0; if (arg != 0) { @@ -83,6 +101,12 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int switch (cmd) { case IOCTL_OPEN: + if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) + { + result = -EBUSY; + break; + } + locked = 1; if (dev->imgname != NULL) { result = -EBUSY; @@ -165,20 +189,22 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int break; case IOCTL_CLOSE: - dnbd3_blk_fail_all_requests(dev); - result = dnbd3_net_disconnect(dev); - dnbd3_blk_fail_all_requests(dev); - blk_mq_freeze_queue(dev->queue); - set_capacity(dev->disk, 0); - blk_mq_unfreeze_queue(dev->queue); - if (dev->imgname) + if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) { - kfree(dev->imgname); - dev->imgname = NULL; + result = -EBUSY; + break; } + locked = 1; + result = dnbd3_close_device(dev); break; case IOCTL_SWITCH: + if (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) + { + result = -EBUSY; + break; + } + locked = 1; if (dev->imgname == NULL) { result = -ENOENT; @@ -278,6 +304,9 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int break; } + if (locked) + atomic_set(&dev->connection_lock, 0); + cleanup_return: if (msg) kfree(msg); if (imgname) kfree(imgname); @@ -354,7 +383,7 @@ int dnbd3_blk_add_device(dnbd3_device_t *dev, int minor) dev->thread_receive = NULL; dev->thread_discover = NULL; dev->discover = 0; - dev->disconnecting = 0; + atomic_set(&dev->connection_lock, 0); dev->panic = 0; dev->panic_count = 0; dev->reported_size = 0; @@ -432,8 +461,10 @@ out: int dnbd3_blk_del_device(dnbd3_device_t *dev) { + while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) + schedule(); + dnbd3_close_device(dev); dnbd3_sysfs_exit(dev); - dnbd3_net_disconnect(dev); del_gendisk(dev->disk); blk_cleanup_queue(dev->queue); blk_mq_free_tag_set(&dev->tag_set); diff --git a/src/kernel/dnbd3_main.h b/src/kernel/dnbd3_main.h index a3c2828..ec8c8cf 100644 --- a/src/kernel/dnbd3_main.h +++ b/src/kernel/dnbd3_main.h @@ -62,7 +62,8 @@ typedef struct dnbd3_server_t alt_servers[NUMBER_SERVERS]; // array of alt servers int new_servers_num; // number of new alt servers that are waiting to be copied to above array dnbd3_server_entry_t new_servers[NUMBER_SERVERS]; // pending new alt servers - uint8_t discover, panic, disconnecting, update_available, panic_count; + uint8_t discover, panic, update_available, panic_count; + atomic_t connection_lock; uint8_t use_server_provided_alts; uint16_t rid; uint32_t heartbeat_count; diff --git a/src/kernel/net.c b/src/kernel/net.c index a0444d2..f460458 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -74,6 +74,8 @@ #define dnbd3_dev_dbg_host_alt(dev, fmt, ...) __dnbd3_dev_dbg_host((dev), (dev)->alt_servers[i].host, fmt, ##__VA_ARGS__) #define dnbd3_dev_err_host_alt(dev, fmt, ...) __dnbd3_dev_err_host((dev), (dev)->alt_servers[i].host, fmt, ##__VA_ARGS__) +static struct socket* dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host); + static inline dnbd3_server_t *get_free_alt_server(dnbd3_device_t * const dev) { int i; @@ -167,25 +169,6 @@ static int dnbd3_net_discover(void *data) struct request *last_request = (struct request *)123, *cur_request = (struct request *)456; -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) - struct __kernel_sock_timeval timeout; -#else - struct timeval timeout; -#endif -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0) - sockptr_t timeout_ptr; -#else - char *timeout_ptr; -#endif - - timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DISCOVERY; - timeout.tv_usec = 0; -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0) - timeout_ptr = KERNEL_SOCKPTR(&timeout); -#else - timeout_ptr = (char *)&timeout; -#endif - memset(&sin4, 0, sizeof(sin4)); memset(&sin6, 0, sizeof(sin6)); @@ -297,35 +280,11 @@ static int dnbd3_net_discover(void *data) continue; // Initialize socket and connect - if (dnbd3_sock_create(dev->alt_servers[i].host.type, SOCK_STREAM, IPPROTO_TCP, &sock) < 0) + sock = dnbd3_connect(dev, &dev->alt_servers[i].host); + if (sock == NULL) { - dnbd3_dev_err_host_alt(dev, "couldn't create socket (discover)\n"); - sock = NULL; - continue; - } -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) - sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout)); - sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout)); -#else - sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout)); - sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout)); -#endif - sock->sk->sk_allocation = GFP_NOIO; - if (dev->alt_servers[i].host.type == HOST_IP4) - { - sin4.sin_family = AF_INET; - memcpy(&sin4.sin_addr, dev->alt_servers[i].host.addr, 4); - sin4.sin_port = dev->alt_servers[i].host.port; - if (kernel_connect(sock, (struct sockaddr *)&sin4, sizeof(sin4), 0) < 0) - goto error; - } - else - { - sin6.sin6_family = AF_INET6; - memcpy(&sin6.sin6_addr, dev->alt_servers[i].host.addr, 16); - sin6.sin6_port = dev->alt_servers[i].host.port; - if (kernel_connect(sock, (struct sockaddr *)&sin6, sizeof(sin6), 0) < 0) - goto error; + dnbd3_dev_dbg_host_alt(dev, "dnbd3_net_discover: Couldn't connect\n"); + goto error; } // Request filesize @@ -413,17 +372,27 @@ static int dnbd3_net_discover(void *data) // panic mode, take first responding server if (dev->panic) { - dev->panic = 0; dnbd3_dev_dbg_host_alt(dev, "panic mode, changing server ...\n"); - if (best_sock != NULL ) - sock_release(best_sock); - dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect(); - kfree(buf); - dev->thread_discover = NULL; - dnbd3_net_disconnect(dev); - memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server)); - dnbd3_net_connect(dev); - return 0; + while (atomic_cmpxchg(&dev->connection_lock, 0, 1) != 0) + { + schedule(); + } + if (dev->panic) // Re-check, a connect might have been in progress + { + dev->panic = 0; + if (best_sock != NULL ) + sock_release(best_sock); + + dev->better_sock = sock; // Pass over socket to take a shortcut in *_connect(); + kfree(buf); + dev->thread_discover = NULL; + dnbd3_net_disconnect(dev); + memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server)); + dnbd3_net_connect(dev); + atomic_set(&dev->connection_lock, 0); + return 0; + } + atomic_set(&dev->connection_lock, 0); } // Request block @@ -506,10 +475,12 @@ static int dnbd3_net_discover(void *data) continue; - error: ; +error: ; ++dev->alt_servers[i].failures; - sock_release(sock); - sock = NULL; + if (sock != NULL) { + sock_release(sock); + sock = NULL; + } dev->alt_servers[i].rtts[turn] = RTT_UNREACHABLE; if (is_same_server(&dev->cur_server, &dev->alt_servers[i])) { @@ -557,7 +528,8 @@ static int dnbd3_net_discover(void *data) } // take server with lowest rtt - if (do_change) + // if a (dis)connect is already in progress, we do nothing, this is not panic mode + if (do_change && atomic_cmpxchg(&dev->connection_lock, 0, 1) == 0) { dev_info(dnbd3_device_to_dev(dev), "server %d is faster (%lluµs vs. %lluµs)\n", best_server, (unsigned long long)best_rtt, (unsigned long long)dev->cur_rtt); @@ -568,7 +540,7 @@ static int dnbd3_net_discover(void *data) memcpy(&dev->cur_server, &dev->alt_servers[best_server], sizeof(dev->cur_server)); dev->cur_rtt = best_rtt; dnbd3_net_connect(dev); - + atomic_set(&dev->connection_lock, 0); return 0; } @@ -658,7 +630,7 @@ static int dnbd3_net_send(void *data) break; default: - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dev_err(dnbd3_device_to_dev(dev), "unknown command (send %u %u)\n", (int)blk_request->cmd_flags, (int)dnbd3_req_op(blk_request)); list_del_init(&blk_request->queuelist); spin_unlock_irqrestore(&dev->blk_lock, irqflags); @@ -673,7 +645,7 @@ static int dnbd3_net_send(void *data) iov.iov_len = sizeof(dnbd3_request); if (kernel_sendmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request)) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "connection to server lost (send)\n"); ret = -ESHUTDOWN; goto cleanup; @@ -687,14 +659,14 @@ static int dnbd3_net_send(void *data) cleanup: if (dev->sock) kernel_sock_shutdown(dev->sock, SHUT_RDWR); - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) { dev->panic = 1; dev->discover = 1; wake_up(&dev->process_queue_discover); } - if (!dev->disconnecting && ret != 0) + if (!atomic_read(&dev->connection_lock) && ret != 0) dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated abnormally\n"); else dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated normally (cleanup)\n"); @@ -749,7 +721,7 @@ static int dnbd3_net_receive(void *data) if (jiffies < recv_timeout) recv_timeout = jiffies; // Handle overflow if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "receive timeout reached (%d of %d secs)\n", (int)((jiffies - recv_timeout) / HZ), (int)SOCKET_KEEPALIVE_TIMEOUT); ret = -ETIMEDOUT; goto cleanup; @@ -757,7 +729,7 @@ static int dnbd3_net_receive(void *data) continue; } else { /* for all errors other than -EAGAIN, print message and abort thread */ - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "connection to server lost (receive)\n"); ret = -ESHUTDOWN; goto cleanup; @@ -767,7 +739,7 @@ static int dnbd3_net_receive(void *data) /* check if arrived data is valid */ if (ret != sizeof(dnbd3_reply)) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "recv msg header\n"); ret = -EINVAL; goto cleanup; @@ -777,15 +749,13 @@ static int dnbd3_net_receive(void *data) // check error if (dnbd3_reply.magic != dnbd3_packet_magic) { - if (!dev->disconnecting) - dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n"); + dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n"); ret = -EINVAL; goto cleanup; } if (dnbd3_reply.cmd == 0) { - if (!dev->disconnecting) - dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n"); + dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n"); ret = -EINVAL; goto cleanup; } @@ -811,9 +781,8 @@ static int dnbd3_net_receive(void *data) spin_unlock_irqrestore(&dev->blk_lock, irqflags); if (blk_request == NULL) { - if (!dev->disconnecting) - dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n", - (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size); + dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n", + (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size); ret = -EINVAL; goto cleanup; } @@ -837,7 +806,7 @@ static int dnbd3_net_receive(void *data) } else { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "receiving from net to block layer\n"); ret = -EINVAL; goto cleanup; @@ -869,7 +838,7 @@ static int dnbd3_net_receive(void *data) if (kernel_recvmsg(dev->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n"); ret = -EINVAL; goto cleanup; @@ -888,7 +857,7 @@ static int dnbd3_net_receive(void *data) ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); if (ret <= 0) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dnbd3_dev_err_host_cur(dev, "recv additional payload from CMD_GET_SERVERS\n"); ret = -EINVAL; goto cleanup; @@ -900,15 +869,14 @@ static int dnbd3_net_receive(void *data) case CMD_LATEST_RID: if (dnbd3_reply.size != 2) { - if (!dev->disconnecting) - dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size != 2\n"); + dev_err(dnbd3_device_to_dev(dev), "CMD_LATEST_RID.size != 2\n"); continue; } iov.iov_base = &rid; iov.iov_len = sizeof(rid); if (kernel_recvmsg(dev->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0) { - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) dev_err(dnbd3_device_to_dev(dev), "could not receive CMD_LATEST_RID payload\n"); } else @@ -922,14 +890,12 @@ static int dnbd3_net_receive(void *data) case CMD_KEEPALIVE: if (dnbd3_reply.size != 0) { - if (!dev->disconnecting) - dev_err(dnbd3_device_to_dev(dev), "keep alive packet with payload\n"); + dev_err(dnbd3_device_to_dev(dev), "keep alive packet with payload\n"); } continue; default: - if (!dev->disconnecting) - dev_err(dnbd3_device_to_dev(dev), "unknown command (receive)\n"); + dev_err(dnbd3_device_to_dev(dev), "unknown command (receive)\n"); continue; } @@ -942,14 +908,14 @@ static int dnbd3_net_receive(void *data) cleanup: if (dev->sock) kernel_sock_shutdown(dev->sock, SHUT_RDWR); - if (!dev->disconnecting) + if (!atomic_read(&dev->connection_lock)) { dev->panic = 1; dev->discover = 1; wake_up(&dev->process_queue_discover); } - if (!dev->disconnecting && ret != 0) + if (!atomic_read(&dev->connection_lock) && ret != 0) dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated abnormally\n"); else dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated normally (cleanup)\n"); @@ -958,9 +924,10 @@ cleanup: return ret; } -int dnbd3_net_connect(dnbd3_device_t *dev) +static struct socket* dnbd3_connect(dnbd3_device_t *dev, dnbd3_host_t *host) { - struct request *req1 = NULL; + int ret; + struct socket *sock; #if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) struct __kernel_sock_timeval timeout; #else @@ -968,26 +935,81 @@ int dnbd3_net_connect(dnbd3_device_t *dev) #endif #if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0) sockptr_t timeout_ptr; + timeout_ptr = KERNEL_SOCKPTR(&timeout); #else char *timeout_ptr; + timeout_ptr = (char *)&timeout; #endif - if (dev->disconnecting) + timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA; + timeout.tv_usec = 0; + + if (dnbd3_sock_create(host->type, SOCK_STREAM, IPPROTO_TCP, &sock) < 0) { - dnbd3_dev_dbg_host_cur(dev, "connect: wait for disconnect has finished ...\n"); - set_current_state(TASK_INTERRUPTIBLE); - while (dev->disconnecting) - schedule(); - dnbd3_dev_dbg_host_cur(dev, "connect: disconnect is done\n"); + dev_err(dnbd3_device_to_dev(dev), "couldn't create socket\n"); + return NULL; } - timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA; - timeout.tv_usec = 0; -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,9,0) - timeout_ptr = KERNEL_SOCKPTR(&timeout); +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) + sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout)); + sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout)); #else - timeout_ptr = (char *)&timeout; + sock_setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout)); + sock_setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout)); #endif + sock->sk->sk_allocation = GFP_NOIO; + if (host->type == HOST_IP4) + { + struct sockaddr_in sin; + memset(&sin, 0, sizeof(sin)); + sin.sin_family = AF_INET; + memcpy(&(sin.sin_addr), host->addr, 4); + sin.sin_port = host->port; + ret = kernel_connect(sock, (struct sockaddr *)&sin, sizeof(sin), O_NONBLOCK); + if (ret != 0 && ret != -EINPROGRESS) + { + dev_err(dnbd3_device_to_dev(dev), "connection to host failed (v4)\n"); + goto error; + } + } + else + { + struct sockaddr_in6 sin; + memset(&sin, 0, sizeof(sin)); + sin.sin6_family = AF_INET6; + memcpy(&(sin.sin6_addr), host->addr, 16); + sin.sin6_port = host->port; + ret = kernel_connect(sock, (struct sockaddr *)&sin, sizeof(sin), O_NONBLOCK); + if (ret != 0 && ret != -EINPROGRESS) + { + dev_err(dnbd3_device_to_dev(dev), "connection to host failed (v6)\n"); + goto error; + } + } + if (ret != 0) { + // XXX How can we do a connect with short timeout? This is dumb + ktime_t start = ktime_get_real(); + while (ktime_ms_delta(ktime_get_real(), start) < SOCKET_TIMEOUT_CLIENT_DATA * 1000) { + struct sockaddr_storage addr; + ret = kernel_getpeername(sock, (struct sockaddr*)&addr); + if (ret >= 0) + break; + msleep(1); + } + if (ret < 0) { + dev_dbg(dnbd3_device_to_dev(dev), "connect timed out (%d)\n", ret); + goto error; + } + } + return sock; +error: + sock_release(sock); + return NULL; +} + +int dnbd3_net_connect(dnbd3_device_t *dev) +{ + struct request *req1 = NULL; // do some checks before connecting req1 = kmalloc(sizeof(*req1), GFP_ATOMIC); @@ -1029,46 +1051,13 @@ int dnbd3_net_connect(dnbd3_device_t *dev) int mlen; init_msghdr(msg); - if (dnbd3_sock_create(dev->cur_server.host.type, SOCK_STREAM, IPPROTO_TCP, &dev->sock) < 0) + dev->sock = dnbd3_connect(dev, &dev->cur_server.host); + if (dev->sock == NULL) { - dnbd3_dev_err_host_cur(dev, "couldn't create socket (v6)\n"); + dnbd3_dev_err_host_cur(dev, "dnbd3_net_connect: Failed\n"); goto error; } -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) - sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout)); - sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout)); -#else - sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout)); - sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout)); -#endif - dev->sock->sk->sk_allocation = GFP_NOIO; - if (dev->cur_server.host.type == HOST_IP4) - { - struct sockaddr_in sin; - memset(&sin, 0, sizeof(sin)); - sin.sin_family = AF_INET; - memcpy(&(sin.sin_addr), dev->cur_server.host.addr, 4); - sin.sin_port = dev->cur_server.host.port; - if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0) - { - dnbd3_dev_err_host_cur(dev, "connection to host failed (v4)\n"); - goto error; - } - } - else - { - struct sockaddr_in6 sin; - memset(&sin, 0, sizeof(sin)); - sin.sin6_family = AF_INET6; - memcpy(&(sin.sin6_addr), dev->cur_server.host.addr, 16); - sin.sin6_port = dev->cur_server.host.port; - if (kernel_connect(dev->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0) - { - dnbd3_dev_err_host_cur(dev, "connection to host failed (v6)\n"); - } - } - // Request filesize dnbd3_request.magic = dnbd3_packet_magic; dnbd3_request.cmd = CMD_SELECT_IMAGE; @@ -1159,28 +1148,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev) dnbd3_dev_dbg_host_cur(dev, "on-the-fly server change ...\n"); dev->sock = dev->better_sock; dev->better_sock = NULL; -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,1,0) - sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO_NEW, timeout_ptr, sizeof(timeout)); - sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO_NEW, timeout_ptr, sizeof(timeout)); -#else - sock_setsockopt(dev->sock, SOL_SOCKET, SO_SNDTIMEO, timeout_ptr, sizeof(timeout)); - sock_setsockopt(dev->sock, SOL_SOCKET, SO_RCVTIMEO, timeout_ptr, sizeof(timeout)); -#endif } - dev->panic = 0; - dev->panic_count = 0; - - // Enqueue request to request_queue_send for a fresh list of alt servers - dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS); - list_add(&req1->queuelist, &dev->request_queue_send); - // create required threads dev->thread_send = kthread_create(dnbd3_net_send, dev, "%s-send", dev->disk->disk_name); - dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name); - dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name); - - // start them up if (!IS_ERR(dev->thread_send)) { get_task_struct(dev->thread_send); wake_up_process(dev->thread_send); @@ -1188,9 +1159,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev) dev_err(dnbd3_device_to_dev(dev), "failed to create send thread\n"); /* reset error to cleanup thread */ dev->thread_send = NULL; - goto cleanup_thread; + goto error; } + dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name); if (!IS_ERR(dev->thread_receive)) { get_task_struct(dev->thread_receive); wake_up_process(dev->thread_receive); @@ -1198,9 +1170,10 @@ int dnbd3_net_connect(dnbd3_device_t *dev) dev_err(dnbd3_device_to_dev(dev), "failed to create receive thread\n"); /* reset error to cleanup thread */ dev->thread_receive = NULL; - goto cleanup_thread; + goto error; } + dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name); if (!IS_ERR(dev->thread_discover)) { get_task_struct(dev->thread_discover); wake_up_process(dev->thread_discover); @@ -1208,12 +1181,21 @@ int dnbd3_net_connect(dnbd3_device_t *dev) dev_err(dnbd3_device_to_dev(dev), "failed to create discover thread\n"); /* reset error to cleanup thread */ dev->thread_discover = NULL; - goto cleanup_thread; + goto error; } + dev->panic = 0; + dev->panic_count = 0; + + // Enqueue request to request_queue_send for a fresh list of alt servers + dnbd3_cmd_to_priv(req1, CMD_GET_SERVERS); + list_add(&req1->queuelist, &dev->request_queue_send); + wake_up(&dev->process_queue_send); // add heartbeat timer + // Do not goto error after creating the timer - we require that the timer exists + // if dev->sock != NULL -- see dnbd3_net_disconnect dev->heartbeat_count = 0; timer_setup(&dev->hb_timer, dnbd3_net_heartbeat, 0); dev->hb_timer.expires = jiffies + HZ; @@ -1221,9 +1203,6 @@ int dnbd3_net_connect(dnbd3_device_t *dev) return 0; -cleanup_thread: - dnbd3_net_disconnect(dev); - error: if (dev->sock) { @@ -1244,22 +1223,16 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) bool thread_not_terminated = false; int ret = 0; - if (dev->disconnecting) { - ret = -EBUSY; - goto out; - } - - dev->disconnecting = 1; - dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n"); - // clear heartbeat timer - del_timer(&dev->hb_timer); dev->discover = 0; - if (dev->sock) + if (dev->sock) { kernel_sock_shutdown(dev->sock, SHUT_RDWR); + // clear heartbeat timer + del_timer(&dev->hb_timer); + } // kill sending and receiving threads if (dev->thread_send) @@ -1278,6 +1251,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) thread_not_terminated = true; } } + dev->thread_send = NULL; } if (dev->thread_receive) @@ -1296,6 +1270,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) thread_not_terminated = true; } } + dev->thread_receive = NULL; } if (dev->thread_discover) @@ -1314,6 +1289,7 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) thread_not_terminated = true; } } + dev->thread_discover = NULL; } if (dev->sock) @@ -1332,8 +1308,5 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) ret = 0; } - dev->disconnecting = 0; - -out: return ret; } -- cgit v1.2.3-55-g7522