From ef92307fd49e75482c7599caf68685afc1807512 Mon Sep 17 00:00:00 2001 From: Manuel Bentele Date: Fri, 6 Nov 2020 12:44:43 +0100 Subject: [KERNEL, CLIENT]: submit and probe multiple dnbd3-server with ioctl OPEN The ioctl OPEN call for DNBD3 devices exposed by the dnbd3 Linux kernel module, is extended with a fixed array of dnbd3 hosts. The fixed array allows the dnbd3-client to submit host information (IP address and port) of multiple dnbd3-servers. This information is used to probe all submitted dnbd3-servers and add them to the alternative dnbd3-server list. If at least one dnbd3-server is not reachable, the OPEN ioctl call will abort with an error code. --- inc/dnbd3/types.h | 6 +- src/client/client.c | 68 +++++++++++--- src/kernel/blk.c | 54 ++++++++--- src/kernel/dnbd3_main.h | 1 - src/kernel/net.c | 244 +++++++++++++++++++++++++++++++++--------------- 5 files changed, 268 insertions(+), 105 deletions(-) diff --git a/inc/dnbd3/types.h b/inc/dnbd3/types.h index 63e182c..59bf2d1 100644 --- a/inc/dnbd3/types.h +++ b/inc/dnbd3/types.h @@ -128,10 +128,14 @@ typedef struct __attribute__((packed)) dnbd3_host_t dnbd3_af type; // 1byte (ip version. HOST_IP4 or HOST_IP6. 0 means this struct is empty and should be ignored) } dnbd3_host_t; +/* IOCTLs */ +#define MAX_HOSTS_PER_IOCTL NUMBER_SERVERS + typedef struct __attribute__((packed)) { uint16_t len; - dnbd3_host_t host; + dnbd3_host_t hosts[MAX_HOSTS_PER_IOCTL]; + uint8_t hosts_num; uint16_t imgnamelen; char *imgname; int rid; diff --git a/src/client/client.c b/src/client/client.c index c309551..fd2770f 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -41,7 +41,7 @@ #define SOCK_BUFFER 1000 #define DEV_LEN 15 #define MAX_DEVS 50 - +#define TMP_STR_LEN 100 static int openDevices[MAX_DEVS]; static const char *optString = "f:h:i:r:d:a:cs:SA:R:HV?k"; @@ -194,6 +194,42 @@ static int dnbd3_get_ip(char *hostname, dnbd3_host_t *host) return true; } +/* parses hosts from space separated cmdln string, resolves them and saves them into hosts */ +static int dnbd3_get_resolved_hosts(char *hosts_str, dnbd3_host_t *hosts, const size_t hosts_len) +{ + char *hosts_current_token = hosts_str; + char *hosts_last_host; + int hosts_index = 0; + char host_str[TMP_STR_LEN]; + size_t host_str_len = 0; + + do { + /* get next host from string */ + while (*hosts_current_token == ' ') { + hosts_current_token++; + } + + /* buffer substring of host to get ip from it */ + hosts_last_host = strchr(hosts_current_token, ' '); + host_str_len = (hosts_last_host == NULL ? TMP_STR_LEN : (size_t)(hosts_last_host - hosts_current_token) + 1); + if ( host_str_len > TMP_STR_LEN ) + host_str_len = TMP_STR_LEN; + + snprintf(host_str, host_str_len, "%s", hosts_current_token); + + if (!dnbd3_get_ip(host_str, &hosts[hosts_index])) + return false; + + hosts_index++; + + /* continue processing of hosts */ + hosts_current_token = hosts_last_host + 1; + + } while ( hosts_last_host != NULL && hosts_index < hosts_len ); + + return hosts_index; +} + int main(int argc, char *argv[]) { char *dev = NULL; @@ -205,9 +241,8 @@ int main(int argc, char *argv[]) dnbd3_ioctl_t msg; memset( &msg, 0, sizeof(dnbd3_ioctl_t) ); msg.len = (uint16_t)sizeof(dnbd3_ioctl_t); + msg.hosts_num = 0; msg.read_ahead_kb = DEFAULT_READ_AHEAD_KB; - msg.host.port = htons( PORT ); - msg.host.type = 0; msg.imgname = NULL; int opt = 0; @@ -220,7 +255,9 @@ int main(int argc, char *argv[]) case 'f': break; case 'h': - if ( !dnbd3_get_ip( optarg, &msg.host ) ) exit( EXIT_FAILURE ); + msg.hosts_num = dnbd3_get_resolved_hosts(optarg, msg.hosts, MAX_HOSTS_PER_IOCTL); + if (!msg.hosts_num) + exit( EXIT_FAILURE ); break; case 'i': action = IOCTL_OPEN; @@ -240,18 +277,21 @@ int main(int argc, char *argv[]) action = IOCTL_CLOSE; break; case 's': - dnbd3_get_ip( optarg, &msg.host ); + dnbd3_get_ip( optarg, &msg.hosts[0] ); + msg.hosts_num = 1; action = IOCTL_SWITCH; break; case 'S': learnNewServers = false; break; case 'A': - dnbd3_get_ip( optarg, &msg.host ); + dnbd3_get_ip( optarg, &msg.hosts[0] ); + msg.hosts_num = 1; action = IOCTL_ADD_SRV; break; case 'R': - dnbd3_get_ip( optarg, &msg.host ); + dnbd3_get_ip( optarg, &msg.hosts[0] ); + msg.hosts_num = 1; action = IOCTL_REM_SRV; break; case 'H': @@ -295,10 +335,10 @@ int main(int argc, char *argv[]) setuid( getuid() ); } - host_to_string( &msg.host, host, 50 ); + host_to_string( &msg.hosts[0], host, 50 ); // close device - if ( action == IOCTL_CLOSE && msg.host.type == 0 && dev && (msg.imgname == NULL )) { + if ( action == IOCTL_CLOSE && msg.hosts_num == 0 && dev && (msg.imgname == NULL )) { printf( "INFO: Closing device %s\n", dev ); if ( dnbd3_ioctl( dev, IOCTL_CLOSE, &msg ) ) exit( EXIT_SUCCESS ); printf( "Couldn't close device.\n" ); @@ -306,7 +346,7 @@ int main(int argc, char *argv[]) } // switch host - if ( (action == IOCTL_SWITCH || action == IOCTL_ADD_SRV || action == IOCTL_REM_SRV) && msg.host.type != 0 && dev && (msg.imgname == NULL )) { + if ( (action == IOCTL_SWITCH || action == IOCTL_ADD_SRV || action == IOCTL_REM_SRV) && msg.hosts_num == 1 && dev && (msg.imgname == NULL )) { if ( action == IOCTL_SWITCH ) printf( "INFO: Switching device %s to %s\n", dev, host ); if ( action == IOCTL_ADD_SRV ) printf( "INFO: %s: adding %s\n", dev, host ); if ( action == IOCTL_REM_SRV ) printf( "INFO: %s: removing %s\n", dev, host ); @@ -316,7 +356,7 @@ int main(int argc, char *argv[]) } // connect - if ( action == IOCTL_OPEN && msg.host.type != 0 && dev && (msg.imgname != NULL )) { + if ( action == IOCTL_OPEN && msg.hosts_num > 0 && dev && (msg.imgname != NULL )) { printf( "INFO: Connecting device %s to %s for image %s\n", dev, host, msg.imgname ); if ( dnbd3_ioctl( dev, IOCTL_OPEN, &msg ) ) exit( EXIT_SUCCESS ); printf( "ERROR: connecting device failed. Maybe it's already connected?\n" ); @@ -530,7 +570,7 @@ static int dnbd3_daemon_ioctl(int uid, char *device, int action, const char *act memset( &msg, 0, sizeof(msg) ); msg.len = (uint16_t)sizeof(msg); if ( host != NULL ) { - dnbd3_get_ip( host, &msg.host ); + dnbd3_get_ip( host, &msg.hosts[0] ); } if ( index < 0 || index >= MAX_DEVS ) { printf( "%s request with invalid device id %d\n", actionName, index ); @@ -578,7 +618,7 @@ static char* dnbd3_daemon_open(int uid, char *host, char *image, int rid, int re // Open dnbd3_ioctl_t msg; msg.len = (uint16_t)sizeof(msg); - if ( !dnbd3_get_ip( host, &msg.host ) ) { + if ( !dnbd3_get_ip( host, &msg.hosts[0] ) ) { printf( "Cannot parse host address %s\n", host ); return NULL ; } @@ -665,7 +705,7 @@ static void dnbd3_print_help(char *argv_0) printf( " -h -i [-r ] -d [-a ] || -c -d \n\n" ); printf( "Start the DNBD3 client.\n\n" ); //printf("-f or --file \t\t Configuration file (default /etc/dnbd3-client.conf)\n"); - printf( "-h or --host \t\t Host running dnbd3-server.\n" ); + printf( "-h or --host \t\t List of space separated hosts to use.\n" ); printf( "-i or --image \t\t Image name of exported image.\n" ); printf( "-r or --rid \t\t Release-ID of exported image (default 0, latest).\n" ); printf( "-d or --device \t\t DNBD3 device name.\n" ); diff --git a/src/kernel/blk.c b/src/kernel/blk.c index 00c3f8f..2ff322a 100644 --- a/src/kernel/blk.c +++ b/src/kernel/blk.c @@ -39,6 +39,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int struct request_queue *blk_queue = dev->disk->queue; char *imgname = NULL; dnbd3_ioctl_t *msg = NULL; + int i = 0; while (dev->disconnecting) { /* do nothing */ } @@ -91,30 +92,58 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int } else { - if (sizeof(msg->host) != sizeof(dev->cur_server.host)) - dev_info(dnbd3_device_to_dev(dev), "odd size bug triggered in IOCTL\n"); - memcpy(&dev->cur_server.host, &msg->host, sizeof(msg->host)); - dev->cur_server.failures = 0; - memcpy(&dev->initial_server, &dev->cur_server, sizeof(dev->initial_server)); + if (sizeof(msg->hosts[0]) != sizeof(dev->cur_server.host)) + dev_warn(dnbd3_device_to_dev(dev), "odd size bug triggered in IOCTL\n"); + + /* assert that at least one and not to many hosts are given */ + if (msg->hosts_num < 1 || msg->hosts_num > NUMBER_SERVERS) { + result = -EINVAL; + break; + } + dev->imgname = imgname; dev->rid = msg->rid; dev->use_server_provided_alts = msg->use_server_provided_alts; - // Forget all alt servers on explicit connect, set first al server to initial server - memset(dev->alt_servers, 0, sizeof(dev->alt_servers[0])*NUMBER_SERVERS); - memcpy(dev->alt_servers, &dev->initial_server, sizeof(dev->alt_servers[0])); + if (blk_queue->backing_dev_info != NULL) { blk_queue->backing_dev_info->ra_pages = (msg->read_ahead_kb * 1024) / PAGE_SIZE; } - if (dnbd3_net_connect(dev) == 0) - { + /* probe and add specified servers */ + /* copy and probe servers in reverse order, so that the first specified server will be remain as inital/current server */ + for (i = msg->hosts_num - 1; i >= 0; i--) { + /* copy provided host into corresponding alt server slot */ + memset(&dev->alt_servers[i], 0, sizeof(dev->alt_servers[i])); + memcpy(&dev->alt_servers[i].host, &msg->hosts[i], sizeof(msg->hosts[i])); + dev->alt_servers[i].failures = 0; + /* probe added alt server */ + memcpy(&dev->cur_server, &dev->alt_servers[i], sizeof(dev->cur_server)); + memcpy(&dev->initial_server, &dev->cur_server, sizeof(dev->initial_server)); + if (dnbd3_net_connect(dev) != 0) { + /* probing server failed, abort IOCTL with error */ + result = -ENOENT; + break; + } + + /* probing server was successful, go on with other servers */ result = 0; + /* do not disconnect last server since this is the current/initial server that should be connected */ + if (i > 0) { + dnbd3_blk_fail_all_requests(dev); + result = dnbd3_net_disconnect(dev); + dnbd3_blk_fail_all_requests(dev); + } + } + + if (result == 0) + { + /* probing was successful */ imgname = NULL; // Prevent kfree at the end } else { - result = -ENOENT; + /* probing failed */ dev->imgname = NULL; } } @@ -154,7 +183,7 @@ static int dnbd3_blk_ioctl(struct block_device *bdev, fmode_t mode, unsigned int } else { - memcpy(&dev->new_servers[dev->new_servers_num].host, &msg->host, sizeof(msg->host)); + memcpy(&dev->new_servers[dev->new_servers_num].host, &msg->hosts[0], sizeof(msg->hosts[0])); dev->new_servers[dev->new_servers_num].failures = (cmd == IOCTL_ADD_SRV ? 0 : 1); // 0 = ADD, 1 = REM ++dev->new_servers_num; result = 0; @@ -231,7 +260,6 @@ int dnbd3_blk_add_device(dnbd3_device_t *dev, int minor) int ret; init_waitqueue_head(&dev->process_queue_send); - init_waitqueue_head(&dev->process_queue_receive); init_waitqueue_head(&dev->process_queue_discover); INIT_LIST_HEAD(&dev->request_queue_send); INIT_LIST_HEAD(&dev->request_queue_receive); diff --git a/src/kernel/dnbd3_main.h b/src/kernel/dnbd3_main.h index 124426b..8be77de 100644 --- a/src/kernel/dnbd3_main.h +++ b/src/kernel/dnbd3_main.h @@ -76,7 +76,6 @@ typedef struct struct task_struct *thread_discover; struct timer_list hb_timer; wait_queue_head_t process_queue_send; - wait_queue_head_t process_queue_receive; wait_queue_head_t process_queue_discover; struct list_head request_queue_send; struct list_head request_queue_receive; diff --git a/src/kernel/net.c b/src/kernel/net.c index 46c369a..57d8cc7 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -27,7 +27,6 @@ #include #include -#include #ifndef MIN #define MIN(a,b) ((a) < (b) ? (a) : (b)) @@ -230,7 +229,7 @@ static int dnbd3_net_discover(void *data) check_order[i] = i; } - for (;;) + while (!kthread_should_stop()) { wait_event_interruptible(dev->process_queue_discover, kthread_should_stop() || dev->discover || dev->thread_discover == NULL); @@ -246,9 +245,9 @@ static int dnbd3_net_discover(void *data) continue; // Check if the list of alt servers needs to be updated and do so if necessary + spin_lock_irqsave(&dev->blk_lock, irqflags); if (dev->new_servers_num) { - spin_lock_irqsave(&dev->blk_lock, irqflags); for (i = 0; i < dev->new_servers_num; ++i) { if (dev->new_servers[i].host.type != HOST_IP4 && dev->new_servers[i].host.type != HOST_IP6) // Invalid entry? @@ -286,8 +285,8 @@ static int dnbd3_net_discover(void *data) alt_server->failures = 0; } dev->new_servers_num = 0; - spin_unlock_irqrestore(&dev->blk_lock, irqflags); } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); current_server = best_server = -1; best_rtt = 0xFFFFFFFul; @@ -611,8 +610,9 @@ static int dnbd3_net_discover(void *data) } - dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_discover terminated normally\n"); kfree(buf); + dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_discover terminated normally\n"); + dev->thread_discover = NULL; return 0; } @@ -634,19 +634,19 @@ static int dnbd3_net_send(void *data) set_user_nice(current, -20); // move already sent requests to request_queue_send again - while (!list_empty(&dev->request_queue_receive)) + spin_lock_irqsave(&dev->blk_lock, irqflags); + if (!list_empty(&dev->request_queue_receive)) { dev_warn(dnbd3_device_to_dev(dev), "request queue was not empty"); - spin_lock_irqsave(&dev->blk_lock, irqflags); list_for_each_entry_safe(blk_request, tmp_request, &dev->request_queue_receive, queuelist) { list_del_init(&blk_request->queuelist); list_add(&blk_request->queuelist, &dev->request_queue_send); } - spin_unlock_irqrestore(&dev->blk_lock, irqflags); } + spin_unlock_irqrestore(&dev->blk_lock, irqflags); - for (;;) + while (!kthread_should_stop()) { wait_event_interruptible(dev->process_queue_send, kthread_should_stop() || !list_empty(&dev->request_queue_send)); @@ -654,6 +654,7 @@ static int dnbd3_net_send(void *data) break; // extract block request + /* lock since we aquire a blk request from the request_queue_send */ spin_lock_irqsave(&dev->blk_lock, irqflags); if (list_empty(&dev->request_queue_send)) { @@ -661,7 +662,6 @@ static int dnbd3_net_send(void *data) continue; } blk_request = list_entry(dev->request_queue_send.next, struct request, queuelist); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); // what to do? switch (dnbd3_req_op(blk_request)) @@ -671,22 +671,17 @@ static int dnbd3_net_send(void *data) dnbd3_request.offset = blk_rq_pos(blk_request) << 9; // *512 dnbd3_request.size = blk_rq_bytes(blk_request); // bytes left to complete entire request // enqueue request to request_queue_receive - spin_lock_irqsave(&dev->blk_lock, irqflags); list_del_init(&blk_request->queuelist); list_add_tail(&blk_request->queuelist, &dev->request_queue_receive); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); break; case DNBD3_REQ_OP_SPECIAL: dnbd3_request.cmd = dnbd3_priv_to_cmd(blk_request); dnbd3_request.size = 0; - spin_lock_irqsave(&dev->blk_lock, irqflags); list_del_init(&blk_request->queuelist); - spin_unlock_irqrestore(&dev->blk_lock, irqflags); break; default: dev_err(dnbd3_device_to_dev(dev), "unknown command (send %u %u)\n", (int)blk_request->cmd_flags, (int)dnbd3_req_op(blk_request)); - spin_lock_irqsave(&dev->blk_lock, irqflags); list_del_init(&blk_request->queuelist); spin_unlock_irqrestore(&dev->blk_lock, irqflags); continue; @@ -694,6 +689,7 @@ static int dnbd3_net_send(void *data) // send net request dnbd3_request.handle = (uint64_t)(uintptr_t)blk_request; // Double cast to prevent warning on 32bit + spin_unlock_irqrestore(&dev->blk_lock, irqflags); fixup_request(dnbd3_request); iov.iov_base = &dnbd3_request; iov.iov_len = sizeof(dnbd3_request); @@ -702,14 +698,13 @@ static int dnbd3_net_send(void *data) dnbd3_dev_err_host_cur(dev, "connection to server lost (send)\n"); goto error; } - wake_up(&dev->process_queue_receive); } dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated normally\n"); dev->thread_send = NULL; return 0; - error: ; +error: if (dev->sock) kernel_sock_shutdown(dev->sock, SHUT_RDWR); if (!dev->disconnecting) @@ -718,6 +713,7 @@ static int dnbd3_net_send(void *data) dev->discover = 1; wake_up(&dev->process_queue_discover); } + dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_send terminated abnormally\n"); dev->thread_send = NULL; return -1; } @@ -735,11 +731,10 @@ static int dnbd3_net_receive(void *data) struct bio_vec *bvec = &bvec_inst; void *kaddr; unsigned long irqflags; - sigset_t blocked, oldset; uint16_t rid; unsigned long int recv_timeout = jiffies; - int count, remaining, ret; + int count, remaining, ret = 0; init_msghdr(msg); set_user_nice(current, -20); @@ -750,25 +745,44 @@ static int dnbd3_net_receive(void *data) iov.iov_base = &dnbd3_reply; iov.iov_len = sizeof(dnbd3_reply); ret = kernel_recvmsg(dev->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags); - if (ret == -EAGAIN) + + /* end thread after socket timeout or reception of data */ + if (kthread_should_stop()) + break; + + /* check return value of kernel_recvmsg() */ + if (ret == 0) { - if (jiffies < recv_timeout) recv_timeout = jiffies; // Handle overflow - if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT) - { - dnbd3_dev_err_host_cur(dev, "receive timeout reached (%d of %d secs)\n", (int)((jiffies - recv_timeout) / HZ), (int)SOCKET_KEEPALIVE_TIMEOUT); - goto error; - } - continue; + /* have not received any data, but remote peer is shutdown properly */ + dnbd3_dev_dbg_host_cur(dev, "remote peer has performed an orderly shutdown\n"); + goto cleanup; } - if (ret <= 0) + else if (ret < 0) { - dnbd3_dev_err_host_cur(dev, "connection to server lost (receive)\n"); - goto error; + if (ret == -EAGAIN) + { + if (jiffies < recv_timeout) recv_timeout = jiffies; // Handle overflow + if ((jiffies - recv_timeout) / HZ > SOCKET_KEEPALIVE_TIMEOUT) + { + dnbd3_dev_err_host_cur(dev, "receive timeout reached (%d of %d secs)\n", (int)((jiffies - recv_timeout) / HZ), (int)SOCKET_KEEPALIVE_TIMEOUT); + ret = -ETIMEDOUT; + goto cleanup; + } + continue; + } else { + /* for all errors other than -EAGAIN, print message and abort thread */ + dnbd3_dev_err_host_cur(dev, "connection to server lost (receive)\n"); + ret = -ESHUTDOWN; + goto cleanup; + } } + + /* check if arrived data is valid */ if (ret != sizeof(dnbd3_reply)) { dnbd3_dev_err_host_cur(dev, "recv msg header\n"); - goto error; + ret = -EINVAL; + goto cleanup; } fixup_reply(dnbd3_reply); @@ -776,12 +790,14 @@ static int dnbd3_net_receive(void *data) if (dnbd3_reply.magic != dnbd3_packet_magic) { dnbd3_dev_err_host_cur(dev, "wrong packet magic (receive)\n"); - goto error; + ret = -EINVAL; + goto cleanup; } if (dnbd3_reply.cmd == 0) { dnbd3_dev_err_host_cur(dev, "command was 0 (Receive)\n"); - goto error; + ret = -EINVAL; + goto cleanup; } // Update timeout @@ -807,27 +823,23 @@ static int dnbd3_net_receive(void *data) { dnbd3_dev_err_host_cur(dev, "received block data for unrequested handle (%llu: %llu)\n", (unsigned long long)dnbd3_reply.handle, (unsigned long long)dnbd3_reply.size); - goto error; + ret = -EINVAL; + goto cleanup; } // receive data and answer to block layer rq_for_each_segment(bvec_inst, blk_request, iter) { - siginitsetinv(&blocked, sigmask(SIGKILL)); - sigprocmask(SIG_SETMASK, &blocked, &oldset); - kaddr = kmap(bvec->bv_page) + bvec->bv_offset; iov.iov_base = kaddr; iov.iov_len = bvec->bv_len; if (kernel_recvmsg(dev->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len) { kunmap(bvec->bv_page); - sigprocmask(SIG_SETMASK, &oldset, NULL ); dnbd3_dev_err_host_cur(dev, "receiving from net to block layer\n"); - goto error; + ret = -EINVAL; + goto cleanup; } kunmap(bvec->bv_page); - - sigprocmask(SIG_SETMASK, &oldset, NULL ); } spin_lock_irqsave(&dev->blk_lock, irqflags); list_del_init(&blk_request->queuelist); @@ -854,7 +866,8 @@ static int dnbd3_net_receive(void *data) != (count * sizeof(dnbd3_server_entry_t))) { dnbd3_dev_err_host_cur(dev, "recv CMD_GET_SERVERS payload\n"); - goto error; + ret = -EINVAL; + goto cleanup; } spin_lock_irqsave(&dev->blk_lock, irqflags); dev->new_servers_num = count; @@ -871,7 +884,8 @@ static int dnbd3_net_receive(void *data) if (ret <= 0) { dnbd3_dev_err_host_cur(dev, "recv additional payload from CMD_GET_SERVERS\n"); - goto error; + ret = -EINVAL; + goto cleanup; } remaining -= ret; } @@ -909,11 +923,9 @@ static int dnbd3_net_receive(void *data) } } - dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated normally\n"); - dev->thread_receive = NULL; - return 0; + goto out; -error: +cleanup: if (dev->sock) kernel_sock_shutdown(dev->sock, SHUT_RDWR); if (!dev->disconnecting) @@ -922,8 +934,14 @@ error: dev->discover = 1; wake_up(&dev->process_queue_discover); } + +out: + if (!ret) + dev_dbg(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated normally\n"); + else + dev_err(dnbd3_device_to_dev(dev), "kthread dnbd3_net_receive terminated abnormally\n"); dev->thread_receive = NULL; - return -1; + return ret; } int dnbd3_net_connect(dnbd3_device_t *dev) @@ -940,20 +958,13 @@ int dnbd3_net_connect(dnbd3_device_t *dev) char *timeout_ptr; #endif - if (dev->disconnecting) { - dnbd3_dev_dbg_host_cur(dev, "CONNECT: still disconnecting!\n"); + if (dev->disconnecting) + { + dnbd3_dev_dbg_host_cur(dev, "connect: wait for disconnect has finished ...\n"); + set_current_state(TASK_INTERRUPTIBLE); while (dev->disconnecting) schedule(); - } - if (dev->thread_receive != NULL) { - dnbd3_dev_dbg_host_cur(dev, "CONNECT: still receiving!\n"); - while (dev->thread_receive != NULL) - schedule(); - } - if (dev->thread_send != NULL) { - dnbd3_dev_dbg_host_cur(dev, "CONNECT: still sending!\n"); - while (dev->thread_send != NULL) - schedule(); + dnbd3_dev_dbg_host_cur(dev, "connect: disconnect is done\n"); } timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA; @@ -1151,13 +1162,40 @@ int dnbd3_net_connect(dnbd3_device_t *dev) list_add(&req1->queuelist, &dev->request_queue_send); // create required threads - dev->thread_send = kthread_create(dnbd3_net_send, dev, dev->disk->disk_name); - dev->thread_receive = kthread_create(dnbd3_net_receive, dev, dev->disk->disk_name); - dev->thread_discover = kthread_create(dnbd3_net_discover, dev, dev->disk->disk_name); + dev->thread_send = kthread_create(dnbd3_net_send, dev, "%s-send", dev->disk->disk_name); + dev->thread_receive = kthread_create(dnbd3_net_receive, dev, "%s-receive", dev->disk->disk_name); + dev->thread_discover = kthread_create(dnbd3_net_discover, dev, "%s-discover", dev->disk->disk_name); + // start them up - wake_up_process(dev->thread_send); - wake_up_process(dev->thread_receive); - wake_up_process(dev->thread_discover); + if (!IS_ERR(dev->thread_send)) { + get_task_struct(dev->thread_send); + wake_up_process(dev->thread_send); + } else { + dev_err(dnbd3_device_to_dev(dev), "failed to create send thread\n"); + /* reset error to cleanup thread */ + dev->thread_send = NULL; + goto cleanup_thread; + } + + if (!IS_ERR(dev->thread_receive)) { + get_task_struct(dev->thread_receive); + wake_up_process(dev->thread_receive); + } else { + dev_err(dnbd3_device_to_dev(dev), "failed to create receive thread\n"); + /* reset error to cleanup thread */ + dev->thread_receive = NULL; + goto cleanup_thread; + } + + if (!IS_ERR(dev->thread_discover)) { + get_task_struct(dev->thread_discover); + wake_up_process(dev->thread_discover); + } else { + dev_err(dnbd3_device_to_dev(dev), "failed to create discover thread\n"); + /* reset error to cleanup thread */ + dev->thread_discover = NULL; + goto cleanup_thread; + } wake_up(&dev->process_queue_send); @@ -1169,6 +1207,9 @@ int dnbd3_net_connect(dnbd3_device_t *dev) return 0; +cleanup_thread: + dnbd3_net_disconnect(dev); + error: if (dev->sock) { @@ -1185,14 +1226,19 @@ error: int dnbd3_net_disconnect(dnbd3_device_t *dev) { - if (dev->disconnecting) - return 0; + struct task_struct* thread = NULL; + bool thread_not_terminated = false; + int ret = 0; - if (dev->cur_server.host.port) - dnbd3_dev_dbg_host_cur(dev, "disconnecting device\n"); + if (dev->disconnecting) { + ret = -EBUSY; + goto out; + } dev->disconnecting = 1; + dev_dbg(dnbd3_device_to_dev(dev), "disconnecting device ...\n"); + // clear heartbeat timer del_timer(&dev->hb_timer); @@ -1204,21 +1250,58 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) // kill sending and receiving threads if (dev->thread_send) { - kthread_stop(dev->thread_send); + dnbd3_dev_dbg_host_cur(dev, "stop send thread\n"); + thread = dev->thread_send; + ret = kthread_stop(thread); + put_task_struct(thread); + if (ret == -EINTR) { + /* thread has never been scheduled and run */ + dev_dbg(dnbd3_device_to_dev(dev), "send thread has never run\n"); + } else { + /* thread has run, check if it has terminated successfully */ + if (dev->thread_send != NULL) { + dev_err(dnbd3_device_to_dev(dev), "send thread was not terminated correctly\n"); + thread_not_terminated = true; + } + } } if (dev->thread_receive) { - kthread_stop(dev->thread_receive); + dnbd3_dev_dbg_host_cur(dev, "stop receive thread\n"); + thread = dev->thread_receive; + ret = kthread_stop(thread); + put_task_struct(thread); + if (ret == -EINTR) { + /* thread has never been scheduled and run */ + dev_dbg(dnbd3_device_to_dev(dev), "receive thread has never run\n"); + } else { + /* thread has run, check if it has terminated successfully */ + if (dev->thread_receive != NULL) { + dev_err(dnbd3_device_to_dev(dev), "receive thread was not terminated correctly\n"); + thread_not_terminated = true; + } + } } if (dev->thread_discover) { - kthread_stop(dev->thread_discover); - dev->thread_discover = NULL; + dnbd3_dev_dbg_host_cur(dev, "stop discover thread\n"); + thread = dev->thread_discover; + ret = kthread_stop(thread); + put_task_struct(thread); + if (ret == -EINTR) { + /* thread has never been scheduled and run */ + dev_dbg(dnbd3_device_to_dev(dev), "discover thread has never run\n"); + } else { + /* thread has run, check if it has terminated successfully */ + if (dev->thread_discover != NULL) { + dev_err(dnbd3_device_to_dev(dev), "discover thread was not terminated correctly\n"); + thread_not_terminated = true; + } + } } - // clear socket if (dev->sock) { sock_release(dev->sock); @@ -1227,7 +1310,16 @@ int dnbd3_net_disconnect(dnbd3_device_t *dev) dev->cur_server.host.type = 0; dev->cur_server.host.port = 0; + if (thread_not_terminated) { + dev_err(dnbd3_device_to_dev(dev), "failed to disconnect device\n"); + ret = -ENODEV; + } else { + dev_dbg(dnbd3_device_to_dev(dev), "device is disconnected\n"); + ret = 0; + } + dev->disconnecting = 0; - return 0; +out: + return ret; } -- cgit v1.2.3-55-g7522