diff options
-rw-r--r-- | src/kernel/net.c | 209 |
1 files changed, 93 insertions, 116 deletions
diff --git a/src/kernel/net.c b/src/kernel/net.c index 921f0d3..0c3993d 100644 --- a/src/kernel/net.c +++ b/src/kernel/net.c @@ -26,9 +26,13 @@ #include "dnbd3.h" #include "clientconfig.h" -#define dnbd3_cmd_to_priv(req, cmd) (req)->cmd_flags = REQ_OP_DRV_IN | ((cmd) << REQ_FLAG_BITS) + +#define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN +#define DNBD3_REQ_OP_CONNECT REQ_OP_DRV_OUT + +#define dnbd3_cmd_to_priv(req, cmd) (req)->cmd_flags = DNBD3_REQ_OP_SPECIAL | ((cmd) << REQ_FLAG_BITS) +#define dnbd3_connect(req) (req)->cmd_flags = DNBD3_REQ_OP_CONNECT | ((CMD_SELECT_IMAGE) << REQ_FLAG_BITS) #define dnbd3_priv_to_cmd(req) ((req)->cmd_flags >> REQ_FLAG_BITS) -#define dnbd3_req_op(req) req_op(req) #define dnbd3_sock_create(af,type,proto,sock) sock_create_kern(&init_net, (af) == HOST_IP4 ? AF_INET : AF_INET6, type, proto, sock) #define KEEPALIVE_TIMER (jiffies + (HZ * TIMER_INTERVAL_KEEPALIVE_PACKET)) @@ -69,7 +73,9 @@ int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct dnbd3_request_t dnbd3_request; dnbd3_reply_t dnbd3_reply; struct msghdr msg; - struct kvec iov; + struct kvec iov[2]; + size_t iov_num = 1; + size_t send_len; struct req_iterator iter; struct bio_vec bvec_inst; struct bio_vec *bvec = &bvec_inst; @@ -77,51 +83,57 @@ int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct void *kaddr; int result, count, remaining; uint16_t rid; + uint64_t reported_size; + char *name; + serialized_buffer_t payload_buffer; sock->pending = req; init_msghdr(msg); dnbd3_request.magic = dnbd3_packet_magic; switch (req_op(req)) { -// case REQ_OP_DISCARD: -// printk(KERN_DEBUG "dnbd3: request operation discard on device %d\n", dev->minor); -// break; -// case REQ_OP_FLUSH: -// printk(KERN_DEBUG "dnbd3: request operation flush on device %d\n", dev->minor); -// break; -// case REQ_OP_WRITE: -// printk(KERN_DEBUG "dnbd3: request operation write on device %d\n", dev->minor); -// break; case REQ_OP_READ: printk(KERN_DEBUG "dnbd3: request operation read\n"); dnbd3_request.cmd = CMD_GET_BLOCK; dnbd3_request.offset = blk_rq_pos(req) << 9; // *512 dnbd3_request.size = blk_rq_bytes(req); // bytes left to complete entire request break; - case REQ_OP_DRV_IN: - printk(KERN_DEBUG "dnbd3: request operation driver in\n"); + case DNBD3_REQ_OP_SPECIAL: + printk(KERN_DEBUG "dnbd3: request operation special\n"); dnbd3_request.cmd = dnbd3_priv_to_cmd(req); dnbd3_request.size = 0; break; - + case DNBD3_REQ_OP_CONNECT: + printk(KERN_DEBUG "dnbd3: request operation connect to %s\n", dev->imgname); + dnbd3_request.cmd = CMD_SELECT_IMAGE; + serializer_reset_write(&payload_buffer); + serializer_put_uint16(&payload_buffer, PROTOCOL_VERSION); + serializer_put_string(&payload_buffer, dev->imgname); + serializer_put_uint16(&payload_buffer, dev->rid); + serializer_put_uint8(&payload_buffer, 0); // is_server = false + iov[1].iov_base = &payload_buffer; + dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&payload_buffer); + iov_num = 2; + break; default: return -EIO; } dnbd3_request.handle = (uint64_t)(uintptr_t)req; // Double cast to prevent warning on 32bit fixup_request(dnbd3_request); - iov.iov_base = &dnbd3_request; - iov.iov_len = sizeof(dnbd3_request); - if (kernel_sendmsg(sock->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request)) { + iov[0].iov_base = &dnbd3_request; + iov[0].iov_len = sizeof(dnbd3_request); + send_len = iov_num == 1 ? sizeof(dnbd3_request) : iov[0].iov_len + iov[1].iov_len; + if (kernel_sendmsg(sock->sock, &msg, iov, iov_num, send_len) != send_len) { printk(KERN_ERR "dnbd3: connection to server lost\n"); result = -EIO; goto error; } // receive net reply - iov.iov_base = &dnbd3_reply; - iov.iov_len = sizeof(dnbd3_reply); - result = kernel_recvmsg(sock->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags); + iov[0].iov_base = &dnbd3_reply; + iov[0].iov_len = sizeof(dnbd3_reply); + result = kernel_recvmsg(sock->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags); if (!result) { printk(KERN_ERR "dnbd3: connection to server lost\n"); result = -EIO; @@ -151,9 +163,9 @@ int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct sigprocmask(SIG_SETMASK, &blocked, &oldset); kaddr = kmap(bvec->bv_page) + bvec->bv_offset; - iov.iov_base = kaddr; - iov.iov_len = bvec->bv_len; - if (kernel_recvmsg(sock->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len) { + iov[0].iov_base = kaddr; + iov[0].iov_len = bvec->bv_len; + if (kernel_recvmsg(sock->sock, &msg, iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len) { kunmap(bvec->bv_page); sigprocmask(SIG_SETMASK, &oldset, NULL ); printk(KERN_ERR "dnbd3: could not receive form net to block layer\n"); @@ -176,9 +188,9 @@ int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t)); if (count != 0) { - iov.iov_base = dev->new_servers; - iov.iov_len = count * sizeof(dnbd3_server_entry_t); - if (kernel_recvmsg(sock->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) { + iov[0].iov_base = dev->new_servers; + iov[0].iov_len = count * sizeof(dnbd3_server_entry_t); + if (kernel_recvmsg(sock->sock, &msg, iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) { printk(KERN_ERR "dnbd3: failed to get servers\n"); mutex_unlock(&dev->device_lock); goto error; @@ -190,9 +202,9 @@ int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct consume_payload: while (remaining > 0) { count = MIN(sizeof(dnbd3_reply), remaining); // Abuse the reply struct as the receive buffer - iov.iov_base = &dnbd3_reply; - iov.iov_len = count; - result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); + iov[0].iov_base = &dnbd3_reply; + iov[0].iov_len = count; + result = kernel_recvmsg(sock->sock, &msg, iov, 1, count, msg.msg_flags); if (result <= 0) { printk(KERN_ERR "dnbd3: failed to receive payload from get servers\n"); mutex_unlock(&dev->device_lock); @@ -208,9 +220,9 @@ consume_payload: goto error; } printk(KERN_DEBUG "dnbd3: latest rid received\n"); - iov.iov_base = &rid; - iov.iov_len = sizeof(rid); - if (kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0) { + iov[0].iov_base = &rid; + iov[0].iov_len = sizeof(rid); + if (kernel_recvmsg(sock->sock, &msg, iov, 1, iov[0].iov_len, msg.msg_flags) <= 0) { printk(KERN_ERR "dnbd3: failed to get latest rid\n"); goto error; } @@ -225,13 +237,51 @@ consume_payload: } printk(KERN_DEBUG "dnbd3: keep alive received\n"); break; + case CMD_SELECT_IMAGE: + printk(KERN_DEBUG "dnbd3: select image received\n"); + // receive reply payload + iov[0].iov_base = &payload_buffer; + iov[0].iov_len = dnbd3_reply.size; + if (kernel_recvmsg(sock->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { + printk(KERN_ERR "dnbd3: could not read CMD_SELECT_IMAGE payload on handshake\n"); + goto error; + } + + // handle/check reply payload + serializer_reset_read(&payload_buffer, dnbd3_reply.size); + sock->server->protocol_version = serializer_get_uint16(&payload_buffer); + if (sock->server->protocol_version < MIN_SUPPORTED_SERVER) { + printk(KERN_ERR "dnbd3: server version is lower than min supported version\n"); + goto error; + } + + name = serializer_get_string(&payload_buffer); + rid = serializer_get_uint16(&payload_buffer); + if (dev->rid != rid && strcmp(name, dev->imgname) != 0) { + printk(KERN_ERR "dnbd3: server offers image '%s', requested '%s'\n", name, dev->imgname); + goto error; + } + reported_size = serializer_get_uint64(&payload_buffer); + if (dev->reported_size == NULL) { + if (reported_size < 4096) { + printk(KERN_ERR "dnbd3: reported size by server is < 4096\n"); + goto error; + } + dev->reported_size = reported_size; + set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ + } else if (dev->reported_size != reported_size) { + printk(KERN_ERR "dnbd3: reported size by server is %llu but should be %llu\n", reported_size, dev->reported_size); + } + + break; default: printk("ERROR: Unknown command (Receive)\n"); break; } sock->pending = NULL; + result = 0; error: return result; } @@ -296,18 +346,9 @@ static void discovery(struct work_struct *work) static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_sock *sock) { int result = -EIO; - struct request *req1 = NULL; + struct request *req = NULL; struct timeval timeout; struct dnbd3_server *server = sock->server; - dnbd3_request_t dnbd3_request; - dnbd3_reply_t dnbd3_reply; - struct msghdr msg; - struct kvec iov[2]; - uint16_t rid; - uint64_t reported_size; - char *name; - int mlen; - serialized_buffer_t payload_buffer; printk(KERN_DEBUG "dnbd3: socket connect device %i\n", dev->minor); @@ -330,13 +371,12 @@ static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_sock *sock) timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA; timeout.tv_usec = 0; - req1 = kmalloc(sizeof(*req1), GFP_ATOMIC ); - if (!req1) { + req = kmalloc(sizeof(*req), GFP_ATOMIC ); + if (!req) { printk(KERN_ERR "dnbd3: kmalloc failed\n"); goto error; } - init_msghdr(msg); if (dnbd3_sock_create(server->host.type, SOCK_STREAM, IPPROTO_TCP, &sock->sock) < 0) { printk(KERN_ERR "dnbd3: could not create socket\n"); @@ -367,80 +407,16 @@ static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_sock *sock) goto error; } } - // Request filesize - dnbd3_request.magic = dnbd3_packet_magic; - dnbd3_request.cmd = CMD_SELECT_IMAGE; - iov[0].iov_base = &dnbd3_request; - iov[0].iov_len = sizeof(dnbd3_request); - serializer_reset_write(&payload_buffer); - serializer_put_uint16(&payload_buffer, PROTOCOL_VERSION); - serializer_put_string(&payload_buffer, dev->imgname); - serializer_put_uint16(&payload_buffer, dev->rid); - serializer_put_uint8(&payload_buffer, 0); // is_server = false - iov[1].iov_base = &payload_buffer; - dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&payload_buffer); - fixup_request(dnbd3_request); - mlen = sizeof(dnbd3_request) + iov[1].iov_len; - if (kernel_sendmsg(sock->sock, &msg, iov, 2, mlen) != mlen) { - printk(KERN_ERR "dnbd3: could not send CMD_SIZE_REQUEST\n"); - goto error; - } - // receive reply header - iov[0].iov_base = &dnbd3_reply; - iov[0].iov_len = sizeof(dnbd3_reply); - if (kernel_recvmsg(sock->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply)) { - printk(KERN_ERR "dnbd3: received corrupted reply header after CMD_SIZE_REQUEST\n"); - goto error; - } - - // check reply header - fixup_reply(dnbd3_reply); - if (dnbd3_reply.cmd != CMD_SELECT_IMAGE || - dnbd3_reply.size < 3 || - dnbd3_reply.size > MAX_PAYLOAD || - dnbd3_reply.magic != dnbd3_packet_magic) { - printk(KERN_ERR "dnbd3: received invalid reply to CMD_SIZE_REQUEST image does not exist on server\n"); - goto error; - } - - // receive reply payload - iov[0].iov_base = &payload_buffer; - iov[0].iov_len = dnbd3_reply.size; - if (kernel_recvmsg(sock->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { - printk(KERN_ERR "dnbd3: could not read CMD_SELECT_IMAGE payload on handshake\n"); + dnbd3_connect(req); + result = dnbd3_send_request(dev, sock, req); + if (result) { + printk(KERN_ERR "dnbd3: connection to image %s failed\n", dev->imgname); goto error; - } - // handle/check reply payload - serializer_reset_read(&payload_buffer, dnbd3_reply.size); - server->protocol_version = serializer_get_uint16(&payload_buffer); - if (server->protocol_version < MIN_SUPPORTED_SERVER) { - printk(KERN_ERR "dnbd3: server version is lower than min supported version\n"); - goto error; - } - - name = serializer_get_string(&payload_buffer); - rid = serializer_get_uint16(&payload_buffer); - if (dev->rid != rid && strcmp(name, dev->imgname) != 0) { - printk(KERN_ERR "dnbd3: server offers image '%s', requested '%s'\n", name, dev->imgname); - goto error; - } - - reported_size = serializer_get_uint64(&payload_buffer); - if (dev->reported_size == NULL) { - if (reported_size < 4096) { - printk(KERN_ERR "dnbd3: reported size by server is < 4096\n"); - goto error; - } - dev->reported_size = reported_size; - set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ - } else if (dev->reported_size != reported_size) { - printk(KERN_ERR "dnbd3: reported size by server is %llu but should be %llu\n", reported_size, dev->reported_size); } printk(KERN_DEBUG "dnbd3: connected to image %s, filesize %llu\n", dev->imgname, dev->reported_size); -// TODO add heartbeat // add heartbeat timer and scheduler for the command INIT_WORK(&sock->keepalive, keepalive); sock->heartbeat_count = 0; @@ -450,14 +426,15 @@ static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_sock *sock) mutex_unlock(&sock->lock); + kfree(req); return 0; error: if (sock->sock) { sock_release(sock->sock); sock->sock = NULL; } - if (req1) { - kfree(req1); + if (req) { + kfree(req); } mutex_unlock(&sock->lock); return result; |