summaryrefslogblamecommitdiffstats
path: root/src/kernel/net.c
blob: 07218035eed79d08668760c3babdffdd2b59c824 (plain) (tree)























                                                               
                       



                         





                                                                                                                                 
                                                                         

                                                                                                                                   

                                                                          








                                                   

                                        
 

                                                                                              
 
                                                            







                                                                                       
                                                       

              
                                                                   

                                                            
                                                                                       



                 












                                                                                           

                                      
                          


                           


                        
                                           





                                                 





                                                                                                

                                                                        


                                                           
                                  
                                                                                                     


                                                                         

                                                                              




                                                                                                     


                            









                                                                                   
 
                                                                                                   
                                     


                                                                                          
                                                                                              
                                                                      


                           





                             
 








                                                                           
         
 










                                                                                               


                           



                                                     

                           

























































































                                                                                                                        

 


                                                                            
 









                                                                                                                             
                                                      
 
                                                                         











































                                                                                                                                                                                 

                                           




                                                                                                         

                                           


















                                                                                                                                            


                                           
 






                                                                                                               
 



                                                                                                                         

                                           
 










                                                                                                                                                  
 




                                                                                  
      











                                                                                            





                                                                                        
                                               






                                                                                   
                            
                                                                

                                                           
                         





                                                                                   
                                
                                  




                                                                                           
                                              







                                                                                      
                            
                 
                                                                           
                                         
                                                                












                                                                                     
 


                                                                                 
 
 


                                                                                                     


                                                       
















































                                                                                                                                                 
 


                                                                          

                                
 











                                                                         




























                                                                                                  


































                                                                                                 





                                                           
 


                                                      
 

                                                                     

                                                                                        
                           
 
         


                                   


                                                                                                             








                                                                
                   





                                         

                           




                                  

                                                                                             
 












                                                                                            




                                                                              












                                                            
                            







                                                  

                                                  
                                                                                 










                                                 


                                                                   





                                                                       


                                                      











                                                                                
/*
 * This file is part of the Distributed Network Block Device 3
 *
 * Copyright(c) 2019 Frederic Robra <frederic@robra.org>
 * Parts copyright 2011-2012 Johann Latocha <johann@latocha.de>
 *
 * This file may be licensed under the terms of the
 * GNU General Public License Version 2 (the ``GPL'').
 *
 * Software distributed under the License is distributed
 * on an ``AS IS'' basis, WITHOUT WARRANTY OF ANY KIND, either
 * express or implied. See the GPL for the specific language
 * governing rights and limitations.
 *
 * You should have received a copy of the GPL along with this
 * program. If not, go to http://www.gnu.org/licenses/gpl.html
 * or write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 *
 */



#include <net/sock.h>
#include <linux/wait.h>

#include "dnbd3.h"
#include "clientconfig.h"


/*
 * Request operations used for requests that are created by this driver
 * itself; they are mapped onto the block layer's REQ_OP_DRV_IN and
 * REQ_OP_DRV_OUT operation codes.
 */
#define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN
#define DNBD3_REQ_OP_CONNECT REQ_OP_DRV_OUT

/*
 * Helpers that pack/unpack a dnbd3 command into req->cmd_flags: the request
 * operation sits in the low bits, the dnbd3 command is shifted above
 * REQ_FLAG_BITS.
 * NOTE: dnbd3_cmd_to_priv and dnbd3_connect expand to an unparenthesized
 * assignment, so use them as stand-alone statements only.
 * dnbd3_sock_create creates a kernel socket in init_net; HOST_IP4 selects
 * AF_INET, every other value selects AF_INET6.
 */
#define dnbd3_cmd_to_priv(req, cmd)   (req)->cmd_flags = DNBD3_REQ_OP_SPECIAL | ((cmd) << REQ_FLAG_BITS)
#define dnbd3_connect(req)			  (req)->cmd_flags = DNBD3_REQ_OP_CONNECT | ((CMD_SELECT_IMAGE) << REQ_FLAG_BITS)
#define dnbd3_priv_to_cmd(req)        ((req)->cmd_flags >> REQ_FLAG_BITS)
#define dnbd3_sock_create(af,type,proto,sock) sock_create_kern(&init_net, (af) == HOST_IP4 ? AF_INET : AF_INET6, type, proto, sock)

/*
 * Absolute expiry times (in jiffies) for the next keepalive / discovery
 * timer run: now plus the configured interval multiplied by HZ.
 */
#define KEEPALIVE_TIMER (jiffies + (HZ * TIMER_INTERVAL_KEEPALIVE_PACKET))
#define DISCOVERY_TIMER (jiffies + (HZ * TIMER_INTERVAL_PROBE_NORMAL))

/*
 * Prepare a struct msghdr for kernel_sendmsg/kernel_recvmsg: no address,
 * no control data, flags MSG_WAITALL | MSG_NOSIGNAL.
 * NOTE: h is not parenthesized in the expansion; pass a plain lvalue.
 */
#define init_msghdr(h) do { \
        h.msg_name = NULL; \
        h.msg_namelen = 0; \
        h.msg_control = NULL; \
        h.msg_controllen = 0; \
        h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \
	} while (0)

/*
 * Wait queue on which dnbd3_send_request_blocking sleeps until its reply was
 * seen. send_wq_handle is written by the receive worker with
 * to_handle(minor, reply command) before it wakes the queue, and is reset to
 * 0 by dnbd3_send_request_blocking. Both are shared by every socket that is
 * handled by this file.
 */
static DECLARE_WAIT_QUEUE_HEAD(send_wq);
static uint64_t send_wq_handle;

/* Forward declarations; both functions are defined further down in this file. */
static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_server *server);
static int dnbd3_socket_disconnect(dnbd3_device *dev, dnbd3_server *server, dnbd3_sock *sock);

/**
 * print_host - log a server address with a descriptive prefix
 * @host: server to print; host->type selects the address format
 * @msg: text that is placed in front of the address
 *
 * Every type other than HOST_IP4 is printed as an IPv6 address.
 */
static void print_host(struct dnbd3_host_t *host, char *msg)
{
	/*
	 * host->port is copied unconverted into sin_port/sin6_port when a
	 * connection is opened, i.e. it is kept in network byte order and has
	 * to be converted before it is printed as a number.
	 */
	int port = ntohs(host->port);

	if (host->type == HOST_IP4) {
		printk(KERN_INFO "dnbd3: %s %pI4:%d\n", msg, host->addr, port);
	} else {
		printk(KERN_INFO "dnbd3: %s [%pI6]:%d\n", msg, host->addr, port);
	}
}

static void print_server_list(struct dnbd3_device *dev)
{
	int i;
	print_host(&dev->initial_server.host, "initial server is");
	for (i = 0; i < NUMBER_SERVERS; i++) {
		if (dev->alt_servers[i].host.addr[0] != 0) {
			print_host(&dev->alt_servers[i].host, "alternative server is");
		}
	}
}

/* Combine two 32 bit values into one handle: arg0 in the upper, arg1 in the lower half. */
static uint64_t to_handle(uint32_t arg0, uint32_t arg1) {
	uint64_t handle = arg0;

	handle <<= 32;
	handle |= arg1;
	return handle;
}

/* Extract the upper 32 bits of a handle (the first argument given to to_handle). */
static uint32_t arg0_from_handle(uint64_t handle) {
	const uint64_t upper = handle >> 32;

	return (uint32_t)(upper & 0xFFFFFFFFULL);
}

/* Extract the lower 32 bits of a handle (the second argument given to to_handle). */
static uint32_t arg1_from_handle(uint64_t handle) {
	const uint64_t lower = handle & 0xFFFFFFFFULL;

	return (uint32_t) lower;
}

int dnbd3_send_request(struct dnbd3_sock *sock, struct request *req, struct dnbd3_cmd *cmd)
{
	dnbd3_request_t dnbd3_request;
	struct msghdr msg;
	struct kvec iov[2];
	size_t iov_num = 1;
	size_t send_len;
	int result;
	uint32_t tag;
	uint64_t handle;
	serialized_buffer_t payload_buffer;
	sock->pending = req;
	init_msghdr(msg);

	dnbd3_request.magic = dnbd3_packet_magic;

	switch (req_op(req)) {
	case REQ_OP_READ:
		printk(KERN_DEBUG "dnbd3: request operation read\n");
		dnbd3_request.cmd = CMD_GET_BLOCK;
		dnbd3_request.offset = blk_rq_pos(req) << 9; // *512
		dnbd3_request.size = blk_rq_bytes(req); // bytes left to complete entire request
		break;
	case DNBD3_REQ_OP_SPECIAL:
		printk(KERN_DEBUG "dnbd3: request operation special\n");
		dnbd3_request.cmd = dnbd3_priv_to_cmd(req);
		dnbd3_request.size = 0;
		break;
	case DNBD3_REQ_OP_CONNECT:
		printk(KERN_DEBUG "dnbd3: request operation connect to %s\n", sock->device->imgname);
		dnbd3_request.cmd = CMD_SELECT_IMAGE;
		serializer_reset_write(&payload_buffer);
		serializer_put_uint16(&payload_buffer, PROTOCOL_VERSION);
		serializer_put_string(&payload_buffer, sock->device->imgname);
		serializer_put_uint16(&payload_buffer, sock->device->rid);
		serializer_put_uint8(&payload_buffer, 0); // is_server = false
		iov[1].iov_base = &payload_buffer;
		dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&payload_buffer);
		iov_num = 2;
		break;
	default:
		return -EIO;
	}
	sock->cookie++;
	if (cmd != NULL) {
		cmd->cookie = sock->cookie;
		tag = blk_mq_unique_tag(req);
		handle = ((uint64_t) tag << 32) | sock->cookie;
	} else {
		handle = sock->cookie;
	}
	memcpy(&dnbd3_request.handle, &handle, sizeof(handle));
	printk(KERN_DEBUG "dnbd3: request handle is %llu\n", dnbd3_request.handle);

//	dnbd3_request.handle = (uint64_t)(uintptr_t)req; // Double cast to prevent warning on 32bit
	fixup_request(dnbd3_request);
	iov[0].iov_base = &dnbd3_request;
	iov[0].iov_len = sizeof(dnbd3_request);
	send_len = iov_num == 1 ? sizeof(dnbd3_request) : iov[0].iov_len + iov[1].iov_len;
	if ((result = kernel_sendmsg(sock->sock, &msg, iov, iov_num, send_len)) != send_len) {
		printk(KERN_ERR "dnbd3: connection to server lost\n");
		goto error;
	}

	sock->pending = NULL;
	result = 0;
error:
	return result;
}


/**
 * dnbd3_send_request_blocking - send a driver command and wait for its reply
 * @sock: connection to use
 * @dnbd3_cmd: CMD_KEEPALIVE, CMD_GET_SERVERS or CMD_SELECT_IMAGE
 *
 * Builds a temporary request for @dnbd3_cmd, sends it while holding
 * sock->lock and then sleeps on send_wq until the receive worker publishes
 * the handle to_handle(minor, @dnbd3_cmd) in send_wq_handle.
 *
 * NOTE(review): send_wq_handle is one variable for all sockets, so
 * concurrent blocking requests can overwrite each other's wake-up handle.
 *
 * Return: 0 once the reply was seen, -ENOMEM if the request could not be
 * allocated, -EINVAL for an unsupported command, the error of
 * dnbd3_send_request if sending failed, or the (negative) result of
 * wait_event_interruptible if the wait was interrupted.
 */
int dnbd3_send_request_blocking(struct dnbd3_sock *sock, int dnbd3_cmd)
{
	int result;
	uint64_t handle;
	struct request *req;

	printk(KERN_DEBUG "dnbd3: starting blocking request\n");

	req = kmalloc(sizeof(*req), GFP_ATOMIC);
	if (!req) {
		printk(KERN_ERR "dnbd3: kmalloc failed\n");
		return -ENOMEM;
	}

	switch (dnbd3_cmd) {
	case CMD_KEEPALIVE:
	case CMD_GET_SERVERS:
		dnbd3_cmd_to_priv(req, dnbd3_cmd);
		break;
	case CMD_SELECT_IMAGE:
		dnbd3_connect(req);
		break;
	default:
		printk(KERN_WARNING "dnbd3: unsupported command for blocking %d\n", dnbd3_cmd);
		result = -EINVAL;
		goto out;
	}

	handle = to_handle(sock->device->minor, dnbd3_cmd);

	mutex_lock(&sock->lock);
	/*
	 * Reset the wake-up handle before the request leaves: the reply may
	 * be processed by the receive worker before we reach the wait below,
	 * and resetting afterwards would wipe that wake-up.
	 */
	send_wq_handle = 0;
	result = dnbd3_send_request(sock, req, NULL);
	mutex_unlock(&sock->lock);
	if (result)
		goto out;

	printk(KERN_DEBUG "dnbd3: blocking request going to sleep wait for handle %llu\n", handle);
	result = wait_event_interruptible(send_wq, handle == send_wq_handle);
	printk(KERN_DEBUG "dnbd3: blocking request woke up with handle %llu\n", handle);

out:
	kfree(req);
	return result;
}

/*
 * Read @remaining payload bytes from @sock and throw them away, so that the
 * next read starts at a reply header again.
 * Returns 0 on success or a negative errno if the socket failed.
 */
static int dnbd3_discard_payload(struct dnbd3_sock *sock, uint32_t remaining)
{
	char scratch[128];
	struct msghdr msg;
	struct kvec iov;
	int chunk, ret;

	init_msghdr(msg);
	while (remaining > 0) {
		chunk = remaining < sizeof(scratch) ? remaining : sizeof(scratch);
		iov.iov_base = scratch;
		iov.iov_len = chunk;
		ret = kernel_recvmsg(sock->sock, &msg, &iov, 1, chunk, msg.msg_flags);
		if (ret <= 0)
			return ret < 0 ? ret : -EIO;
		remaining -= ret; // ret never exceeds chunk, which never exceeds remaining
	}
	return 0;
}

/**
 * dnbd3_receive_work - receive loop of one connection (work item)
 * @work: the 'receive' work embedded in a struct dnbd3_sock
 *
 * Reads reply headers from the socket for as long as sock->sock is set and
 * handles the payload depending on the reply command:
 *  - CMD_GET_BLOCK: copies the data into the segments of the request that is
 *    identified by the tag/cookie in the handle and completes the request
 *  - CMD_GET_SERVERS: stores the server list in dev->new_servers (only if
 *    dev->use_server_provided_alts is set), superfluous data is discarded
 *  - CMD_LATEST_RID: updates dev->update_available
 *  - CMD_KEEPALIVE: must not carry a payload
 *  - CMD_SELECT_IMAGE: validates the handshake reply and sets the capacity
 *
 * After every handled reply (and after every error) waiters on send_wq are
 * woken with the handle to_handle(minor, reply command). Any error ends the
 * loop; the socket is disconnected before the work item returns.
 */
static void dnbd3_receive_work(struct work_struct *work)
{
	struct dnbd3_sock *sock = container_of(work, struct dnbd3_sock, receive);
	struct dnbd3_device *dev = sock->device;
	struct request *req;
	dnbd3_reply_t dnbd3_reply;
	struct dnbd3_cmd *cmd;
	struct msghdr msg;
	struct kvec iov;
	struct req_iterator iter;
	struct bio_vec bvec_inst;
	struct bio_vec *bvec = &bvec_inst;
	sigset_t blocked, oldset;
	void *kaddr;
	uint32_t tag, cookie, remaining;
	uint16_t hwq;
	int result, count;
	uint16_t rid;
	uint64_t reported_size, handle;
	char *name;
	serialized_buffer_t payload_buffer;

	init_msghdr(msg);
	// the error path reads dnbd3_reply.cmd even if no header was received yet
	memset(&dnbd3_reply, 0, sizeof(dnbd3_reply));

	while(sock->sock) {
		iov.iov_base = &dnbd3_reply;
		iov.iov_len = sizeof(dnbd3_reply);
		result = kernel_recvmsg(sock->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags);
		if (result == -EAGAIN) {
			// receive timeout on an idle connection, nothing was consumed: keep waiting
			continue;
		}
		if (result != (int)sizeof(dnbd3_reply)) {
			// 0 (peer closed), a socket error or an incomplete header
			printk(KERN_ERR "dnbd3: connection to server lost\n");
			result = -EIO;
			goto error;
		}
		result = 0;
		fixup_reply(dnbd3_reply);

		// check error
		if (dnbd3_reply.magic != dnbd3_packet_magic) {
			printk(KERN_ERR "dnbd3: wrong magic packet\n");
			result = -EIO;
			goto error;
		}

		if (dnbd3_reply.cmd == 0) {
			printk(KERN_ERR "dnbd3: command was 0\n");
			result = -EIO;
			goto error;
		}


		switch (dnbd3_reply.cmd) {
		case CMD_GET_BLOCK:
			printk(KERN_DEBUG "dnbd3: handle is %llu\n", dnbd3_reply.handle);
			memcpy(&handle, &dnbd3_reply.handle, sizeof(handle));
			cookie = (uint32_t) handle;
			tag = (uint32_t)(handle >> 32);

			// never reuse the request of a previous reply
			req = NULL;
			hwq = blk_mq_unique_tag_to_hwq(tag);
			if (hwq < dev->tag_set.nr_hw_queues)
				req = blk_mq_tag_to_rq(dev->tag_set.tags[hwq], blk_mq_unique_tag_to_tag(tag));
			if (!req || !blk_mq_request_started(req)) {
				dev_err(disk_to_dev(dev->disk), "Unexpected reply (%d) %p\n", tag, req);
				// skip the block data of this reply to stay in sync with the stream
				result = dnbd3_discard_payload(sock, dnbd3_reply.size);
				if (result)
					goto error;
				continue;
			}
			cmd = blk_mq_rq_to_pdu(req);

			mutex_lock(&cmd->lock);
			if (cmd->cookie != cookie) {
				dev_err(disk_to_dev(dev->disk), "Double reply on req %p, cookie %u, handle cookie %u\n",
					req, cmd->cookie, cookie);
				mutex_unlock(&cmd->lock);
				result = dnbd3_discard_payload(sock, dnbd3_reply.size);
				if (result)
					goto error;
				continue;
			}


			rq_for_each_segment(bvec_inst, req, iter) {
				// only SIGKILL may interrupt the copy into the block layer pages
				siginitsetinv(&blocked, sigmask(SIGKILL));
				sigprocmask(SIG_SETMASK, &blocked, &oldset);

				kaddr = kmap(bvec->bv_page) + bvec->bv_offset;
				iov.iov_base = kaddr;
				iov.iov_len = bvec->bv_len;
				if (kernel_recvmsg(sock->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len) {
					kunmap(bvec->bv_page);
					sigprocmask(SIG_SETMASK, &oldset, NULL );
					printk(KERN_ERR "dnbd3: could not receive form net to block layer\n");
					mutex_unlock(&cmd->lock);
					/*
					 * The data of this request is incomplete and the
					 * stream position is unknown: fail the request
					 * and give up on the connection.
					 */
					blk_mq_end_request(req, BLK_STS_IOERR);
					result = -EIO;
					goto error;
				}
				kunmap(bvec->bv_page);

				sigprocmask(SIG_SETMASK, &oldset, NULL );
			}
			mutex_unlock(&cmd->lock);
			blk_mq_end_request(req, 0);
			break;
		case CMD_GET_SERVERS:
			printk(KERN_DEBUG "dnbd3: get servers received\n");
			mutex_lock(&dev->device_lock);
			remaining = dnbd3_reply.size;
			if (dev->use_server_provided_alts) {
				dev->new_servers_num = 0;
				count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t));

				if (count != 0) {
					iov.iov_base = dev->new_servers;
					iov.iov_len = count * sizeof(dnbd3_server_entry_t);
					if (kernel_recvmsg(sock->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) {
						printk(KERN_ERR "dnbd3: failed to get servers\n");
						mutex_unlock(&dev->device_lock);
						result = -EIO;
						goto error;
					}
					dev->new_servers_num = count;
				}
				remaining -= count * sizeof(dnbd3_server_entry_t);
			}
			// If there were more servers than accepted (or the list is not wanted at all),
			// remove the remaining data from the socket buffer
			result = dnbd3_discard_payload(sock, remaining);
			mutex_unlock(&dev->device_lock);
			if (result) {
				printk(KERN_ERR "dnbd3: failed to receive payload from get servers\n");
				goto error;
			}
			break;
		case CMD_LATEST_RID:
			if (dnbd3_reply.size != 2) {
				printk(KERN_ERR "dnbd3: failed to get latest rid, wrong size\n");
				result = -EIO;
				goto error;
			}
			printk(KERN_DEBUG "dnbd3: latest rid received\n");
			iov.iov_base = &rid;
			iov.iov_len = sizeof(rid);
			if (kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) != (int)sizeof(rid)) {
				printk(KERN_ERR "dnbd3: failed to get latest rid\n");
				result = -EIO;
				goto error;
			}
			rid = net_order_16(rid);
			printk("Latest rid of %s is %d (currently using %d)\n", dev->imgname, (int)rid, (int)dev->rid);
			dev->update_available = (rid > dev->rid ? 1 : 0);
			break;
		case CMD_KEEPALIVE:
			if (dnbd3_reply.size != 0) {
				printk(KERN_ERR "dnbd3: got keep alive packet with payload\n");
				result = -EIO;
				goto error;
			}
			printk(KERN_DEBUG "dnbd3: keep alive received\n");
			break;
		case CMD_SELECT_IMAGE:
			printk(KERN_DEBUG "dnbd3: select image received\n");
			/*
			 * The size comes from the network; never receive more than
			 * the on-stack buffer object can hold.
			 * NOTE(review): the serializer may define a tighter payload
			 * limit than sizeof(payload_buffer) - confirm.
			 */
			if (dnbd3_reply.size > sizeof(payload_buffer)) {
				printk(KERN_ERR "dnbd3: CMD_SELECT_IMAGE payload is too large (%u bytes)\n",
						(unsigned int)dnbd3_reply.size);
				result = -EIO;
				goto error;
			}
			// receive reply payload
			iov.iov_base = &payload_buffer;
			iov.iov_len = dnbd3_reply.size;
			if ((result = kernel_recvmsg(sock->sock, &msg, &iov, 1, dnbd3_reply.size, msg.msg_flags)) != dnbd3_reply.size) {
				printk(KERN_ERR "dnbd3: could not read CMD_SELECT_IMAGE payload on handshake, size is %d and should be%d\n",
						result, dnbd3_reply.size);
				result = -EIO;
				goto error;
			}
			result = 0;

			// handle/check reply payload
			serializer_reset_read(&payload_buffer, dnbd3_reply.size);
			sock->server->protocol_version = serializer_get_uint16(&payload_buffer);
			if (sock->server->protocol_version < MIN_SUPPORTED_SERVER) {
				printk(KERN_ERR "dnbd3: server version is lower than min supported version\n");
				result = -EIO;
				goto error;
			}

			name = serializer_get_string(&payload_buffer);
			rid = serializer_get_uint16(&payload_buffer);
			// presumably name is NULL for a malformed payload - guard before strcmp
			if (dev->rid != rid && (name == NULL || strcmp(name, dev->imgname) != 0)) {
				printk(KERN_ERR "dnbd3: server offers image '%s', requested '%s'\n", name, dev->imgname);
				result = -EIO;
				goto error;
			}

			reported_size = serializer_get_uint64(&payload_buffer);
			if (!dev->reported_size) {
				if (reported_size < 4096) {
					printk(KERN_ERR "dnbd3: reported size by server is < 4096\n");
					result = -EIO;
					goto error;
				}
				dev->reported_size = reported_size;
				set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */
			} else if (dev->reported_size != reported_size) {
				// only reported, the connection is kept
				printk(KERN_ERR "dnbd3: reported size by server is %llu but should be %llu\n", reported_size, dev->reported_size);
			}

			break;
		default:
			printk(KERN_WARNING "dnbd3: Unknown command (Receive)\n");
			// drop the payload of the unknown reply to stay in sync with the stream
			result = dnbd3_discard_payload(sock, dnbd3_reply.size);
			break;
		}
error:
		// wake a blocking sender that waits for this command on this device
		handle = to_handle(dev->minor, dnbd3_reply.cmd);
		printk(KERN_DEBUG "dnbd3: try to wake up queue with handle %llu\n", handle);
		send_wq_handle = handle;
		wake_up_interruptible(&send_wq);
		if (result) {
			printk(KERN_DEBUG "dnbd3: receive error happened %d\n", result);
			break; //TODO for now need to handle errors
		}
		printk(KERN_DEBUG "dnbd3: receive completed, waiting for next receive\n");
	}
	printk(KERN_DEBUG "dnbd3: receive work queue is stopped\n");
	dnbd3_socket_disconnect(dev, sock->server, sock);
}


void dnbd3_keepalive(struct timer_list *arg)
{
	struct dnbd3_sock *sock = container_of(arg, struct dnbd3_sock, keepalive_timer);
	queue_work(dnbd3_wq, &sock->keepalive);
	sock->keepalive_timer.expires = KEEPALIVE_TIMER;
	add_timer(&sock->keepalive_timer);
}

static void keepalive(struct work_struct *work)
{
	struct dnbd3_sock *sock = container_of(work, struct dnbd3_sock, keepalive);
//	struct request *req;
	printk(KERN_DEBUG "dnbd3: starting keepalive worker\n");
//	mutex_lock(&sock->lock);
//	req = kmalloc(sizeof(struct request), GFP_ATOMIC );
	// send keepalive
//	if (req) {
		dnbd3_send_request_blocking(sock, CMD_KEEPALIVE);
////		kfree(req);
//	} else {
//		printk(KERN_WARNING "dnbd3: could not create keepalive request\n");
//	}
	++sock->heartbeat_count;
//	mutex_unlock(&sock->lock);
}

void dnbd3_discovery(struct timer_list *arg)
{
	struct dnbd3_device *dev = container_of(arg, struct dnbd3_device, discovery_timer);
	queue_work(dnbd3_wq, &dev->discovery);
	dev->discovery_timer.expires = DISCOVERY_TIMER;
	add_timer(&dev->discovery_timer);
}

static void discovery(struct work_struct *work)
{
	struct dnbd3_device *dev = container_of(work, struct dnbd3_device, discovery);
	dnbd3_sock *sock = &dev->socks[0]; // we use the first sock for discovery
//	struct request *req;
	int i, j;
	struct dnbd3_server *existing_server, *free_server, *failed_server;
	dnbd3_server_entry_t *new_server;
	printk(KERN_DEBUG "dnbd3: starting discovery worker\n");
//	mutex_lock(&sock->lock);
//	req = kmalloc(sizeof(struct request), GFP_ATOMIC );
//	// send keepalive
//	if (req) {
//		dnbd3_cmd_to_priv(req, CMD_GET_SERVERS);
		dnbd3_send_request_blocking(sock, CMD_GET_SERVERS);
//		kfree(req);
//	} else {
//		printk(KERN_WARNING "dnbd3: could not create get servers request\n");
//	}
//	mutex_unlock(&sock->lock);

	//TODO wait until something is received

	printk(KERN_DEBUG "dnbd3: new server num is %d\n", dev->new_servers_num);
	if (dev->new_servers_num) {
		mutex_lock(&dev->device_lock);


		for (i = 0; i < dev->new_servers_num; i++) {
			new_server = &dev->new_servers[i];
			if (new_server->host.type == HOST_IP4 || new_server->host.type == HOST_IP6) {
				existing_server = NULL;
				free_server = NULL;
				failed_server = NULL;

				// find servers in alt servers
				for (j = 0; j < NUMBER_SERVERS; j++) {
					if ((new_server->host.type == dev->alt_servers[j].host.type)
					   && (new_server->host.port == dev->alt_servers[j].host.port)
					   && (0 == memcmp(new_server->host.addr, dev->alt_servers[j].host.addr,
							   (new_server->host.type == HOST_IP4 ? 4 : 16)))) 	{

						existing_server = &dev->alt_servers[j];
					} else if (dev->alt_servers[j].host.type == 0) {
						free_server = &dev->alt_servers[j];
					} else if (dev->alt_servers[j].failures > 20) {
						failed_server = &dev->alt_servers[j];
					}
				}

				if (existing_server) {
					if (new_server->failures == 1) { // remove is requested
						print_host(&existing_server->host, "remove server");
						dnbd3_socket_disconnect(dev, existing_server, NULL); // TODO what to do when only one connection?
						existing_server->host.type = 0;
					}
					// ADD, so just reset fail counter
//					existing_server->failures = 0; makes no sense?
					continue;
				} else if (free_server) {
					free_server->host = new_server->host;
				} else if (failed_server) {
					failed_server->host = new_server->host;
					free_server = failed_server;
				} else {
					//no server found to replace
					continue;
				}
				print_host(&free_server->host, "got new alt server");
				free_server->failures = 0;
				free_server->protocol_version = 0;
				free_server->rtts[0] = free_server->rtts[1] = free_server->rtts[2] = free_server->rtts[3] = RTT_UNREACHABLE;
			}
		}
		dev->new_servers_num = 0;
		mutex_unlock(&dev->device_lock);
	}

	// measure rtt for all alt servers
	for (i = 0; i < NUMBER_SERVERS; i++) {

	}
}

/*
 * Create a TCP socket for @sock and connect it to @server.
 *
 * sock->lock is (re)initialized here and held while the socket is set up.
 * The socket gets send/receive timeouts of SOCKET_TIMEOUT_CLIENT_DATA and
 * allocates with GFP_NOIO.
 *
 * Returns 0 on success. On any failure the socket (if there is one) is
 * released, sock->lock is destroyed and -EIO is returned.
 *
 * NOTE(review): the "already connected" case also runs through the failure
 * path and therefore releases the existing socket - confirm this is wanted.
 */
static int __dnbd3_socket_connect(dnbd3_server * server, dnbd3_sock *sock)
{
	struct timeval timeout = {
		.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA,
		.tv_usec = 0,
	};
	int rc;

	mutex_init(&sock->lock);
	mutex_lock(&sock->lock);

	if (server->host.type == 0 || server->host.port == 0) {
		printk(KERN_ERR "dnbd3: host or port not set\n");
		goto fail;
	}
	if (sock->sock) {
		printk(KERN_WARNING "dnbd3: socket already connected\n");
		goto fail;
	}

	rc = dnbd3_sock_create(server->host.type, SOCK_STREAM, IPPROTO_TCP, &sock->sock);
	if (rc < 0) {
		printk(KERN_ERR "dnbd3: could not create socket\n");
		goto fail;
	}

	kernel_setsockopt(sock->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout));
	kernel_setsockopt(sock->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout));
	sock->sock->sk->sk_allocation = GFP_NOIO;

	if (server->host.type == HOST_IP4) {
		struct sockaddr_in addr4;

		memset(&addr4, 0, sizeof(addr4));
		addr4.sin_family = AF_INET;
		memcpy(&(addr4.sin_addr), server->host.addr, 4);
		addr4.sin_port = server->host.port; /* copied as is, no byte order conversion */
		rc = kernel_connect(sock->sock, (struct sockaddr *)&addr4, sizeof(addr4), 0);
		if (rc != 0) {
			printk(KERN_ERR "dnbd3: connection to host failed (ipv4)\n");
			goto fail;
		}
	} else {
		struct sockaddr_in6 addr6;

		memset(&addr6, 0, sizeof(addr6));
		addr6.sin6_family = AF_INET6;
		memcpy(&(addr6.sin6_addr), server->host.addr, 16);
		addr6.sin6_port = server->host.port; /* copied as is, no byte order conversion */
		rc = kernel_connect(sock->sock, (struct sockaddr *)&addr6, sizeof(addr6), 0);
		if (rc != 0) {
			printk(KERN_ERR "dnbd3: connection to host failed (ipv6)\n");
			goto fail;
		}
	}

	mutex_unlock(&sock->lock);
	return 0;

fail:
	if (sock->sock) {
		sock_release(sock->sock);
		sock->sock = NULL;
	}
	mutex_unlock(&sock->lock);
	mutex_destroy(&sock->lock);
	return -EIO;
}

/**
 * dnbd3_socket_connect - open a new connection of @dev to @server
 * @dev: device that owns the connection slots
 * @server: server to connect to
 *
 * Takes the first slot in dev->socks without a socket, connects it, starts
 * the receive worker, performs the CMD_SELECT_IMAGE handshake and finally
 * arms the keepalive timer.
 *
 * Return: 0 on success; -EIO if no slot is free; otherwise the error of the
 * connect or of the handshake request. On failure the slot is unused again.
 */
static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_server *server)
{
	int i;
	int result;
	struct dnbd3_sock *sock = NULL;

	for (i = 0; i < NUMBER_CONNECTIONS; i++) {
		if (!dev->socks[i].sock) {
			sock = &dev->socks[i];
			break;
		}
	}
	if (sock == NULL) {
		printk(KERN_WARNING "dnbd3: could not connect to socket, to many connections\n");
		return -EIO;
	}
	sock->server = server;

	printk(KERN_DEBUG "dnbd3: socket connect device %i\n", dev->minor);

	result = __dnbd3_socket_connect(server, sock);
	if (result) {
		// the helper already released the socket and destroyed sock->lock
		sock->server = NULL;
		return result;
	}

	/*
	 * Set up the keepalive work and timer before the receiver runs: the
	 * receiver disconnects the socket when it stops, and the disconnect
	 * deletes the keepalive timer, which therefore must be initialized.
	 */
	INIT_WORK(&sock->keepalive, keepalive);
	sock->heartbeat_count = 0;
	timer_setup(&sock->keepalive_timer, dnbd3_keepalive, 0);

	// start the receiver
	INIT_WORK(&sock->receive, dnbd3_receive_work);
	queue_work(dnbd3_wq, &sock->receive);

	result = dnbd3_send_request_blocking(sock, CMD_SELECT_IMAGE);
	if (result) {
		printk(KERN_ERR "dnbd3: connection to image %s failed\n", dev->imgname);
		goto error;
	}

	//TODO wait until connected

	printk(KERN_DEBUG "dnbd3: connected to image %s, filesize %llu\n", dev->imgname, dev->reported_size);

	// start the heartbeat
	sock->keepalive_timer.expires = KEEPALIVE_TIMER;
	add_timer(&sock->keepalive_timer);

	return 0;
error:
	/*
	 * The receiver may be blocked on the socket. Shut the socket down so
	 * the receiver stops, and wait for it before the socket is touched.
	 * A receiver that ran disconnects the socket itself when it stops.
	 * NOTE(review): must not be called from a work item on dnbd3_wq if
	 * that workqueue cannot run the receiver concurrently - confirm.
	 */
	if (sock->sock)
		kernel_sock_shutdown(sock->sock, SHUT_RDWR);
	cancel_work_sync(&sock->receive);
	if (sock->sock) {
		// the receiver never ran, so the cleanup is left to us
		sock_release(sock->sock);
		sock->sock = NULL;
		mutex_destroy(&sock->lock);
	}
	sock->server = NULL;
	return result;
}


/*
 * Tear down one connection of @dev.
 *
 * If @sock is NULL, the first slot in dev->socks whose server is @server is
 * used; -EIO is returned if there is no such slot. The keepalive timer is
 * deleted, the socket (if any) is shut down and released, sock->lock is
 * destroyed and the slot's server pointer is cleared. Returns 0 on success.
 */
static int dnbd3_socket_disconnect(dnbd3_device *dev, dnbd3_server *server, dnbd3_sock *sock)
{
	int idx;

	if (!sock) {
		/* look the socket up by its server, the first match wins */
		for (idx = 0; idx < NUMBER_CONNECTIONS && !sock; idx++) {
			if (dev->socks[idx].server == server)
				sock = &dev->socks[idx];
		}
		if (!sock) {
			printk(KERN_WARNING "dnbd3: could not find socket to disconnect\n");
			return -EIO;
		}
	}

	printk(KERN_DEBUG "dnbd3: socket disconnect device %i\n", dev->minor);
	mutex_lock(&sock->lock);

	/* stop the heartbeat */
	del_timer_sync(&sock->keepalive_timer);

	/* shut down both directions, then drop the socket */
	if (sock->sock) {
		kernel_sock_shutdown(sock->sock, SHUT_RDWR);
		sock_release(sock->sock);
		sock->sock = NULL;
	}

	mutex_unlock(&sock->lock);
	mutex_destroy(&sock->lock);
	sock->server = NULL;
	return 0;
}

/**
 * dnbd3_net_disconnect - stop discovery and close all connections of @dev
 * @dev: device to disconnect
 *
 * Deletes the discovery timer and disconnects every slot in dev->socks that
 * currently has a socket.
 *
 * Return: 0 if every disconnect succeeded, -EIO if at least one failed.
 */
int dnbd3_net_disconnect(struct dnbd3_device *dev)
{
	int i;
	int result = 0; // must start as success, it is only overwritten on failure

	del_timer_sync(&dev->discovery_timer);
	for (i = 0; i < NUMBER_CONNECTIONS; i++) {
		if (dev->socks[i].sock) {
			if (dnbd3_socket_disconnect(dev, NULL, &dev->socks[i])) {
				result = -EIO;
			}
		}
	}
	return result;
}


/*
 * Bring up the network side of @dev: connect to dev->alt_servers[0] and, if
 * that works, log the server list, arm the discovery timer and queue a first
 * discovery run.
 *
 * Returns 0 on success. If the connection fails, dev->imgname and the server
 * of the first socket slot are cleared and -ENOENT is returned.
 */
int dnbd3_net_connect(struct dnbd3_device *dev) {
	// TODO decide which socket to connect
	dev->socks_active = 0;

	if (dnbd3_socket_connect(dev, &dev->alt_servers[0]) != 0) {
		printk(KERN_ERR "dnbd3: failed to connect to initial server\n");
		dev->imgname = NULL;
		dev->socks[0].server = NULL;
		return -ENOENT;
	}

	print_server_list(dev);

	/* periodic discovery */
	INIT_WORK(&dev->discovery, discovery);
	timer_setup(&dev->discovery_timer, dnbd3_discovery, 0);
	dev->discovery_timer.expires = DISCOVERY_TIMER;
	add_timer(&dev->discovery_timer);

	/* and one run right away, to let it discover alt servers */
	queue_work(dnbd3_wq, &dev->discovery);

	return 0;
}