From f9ec2db3b4d1e0047087393218618cf8c439c336 Mon Sep 17 00:00:00 2001 From: Frederic Robra Date: Sun, 7 Jul 2019 22:13:01 +0200 Subject: added first draft for keepalive and discovery --- src/kernel/net.c | 528 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 528 insertions(+) create mode 100644 src/kernel/net.c (limited to 'src/kernel/net.c') diff --git a/src/kernel/net.c b/src/kernel/net.c new file mode 100644 index 0000000..f44925f --- /dev/null +++ b/src/kernel/net.c @@ -0,0 +1,528 @@ +/* + * This file is part of the Distributed Network Block Device 3 + * + * Copyright(c) 2019 Frederic Robra + * Parts copyright 2011-2012 Johann Latocha + * + * This file may be licensed under the terms of of the + * GNU General Public License Version 2 (the ``GPL''). + * + * Software distributed under the License is distributed + * on an ``AS IS'' basis, WITHOUT WARRANTY OF ANY KIND, either + * express or implied. See the GPL for the specific language + * governing rights and limitations. + * + * You should have received a copy of the GPL along with this + * program. If not, go to http://www.gnu.org/licenses/gpl.html + * or write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + * + */ + + + +#include + +#include "dnbd3.h" +#include "clientconfig.h" + +#define dnbd3_cmd_to_priv(req, cmd) (req)->cmd_flags = REQ_OP_DRV_IN | ((cmd) << REQ_FLAG_BITS) +#define dnbd3_priv_to_cmd(req) ((req)->cmd_flags >> REQ_FLAG_BITS) +#define dnbd3_req_op(req) req_op(req) +#define dnbd3_sock_create(af,type,proto,sock) sock_create_kern(&init_net, (af) == HOST_IP4 ? AF_INET : AF_INET6, type, proto, sock) + +#define KEEPALIVE_TIMER (TIMER_INTERVAL_KEEPALIVE_PACKET * (jiffies + HZ)) +#define DISCOVERY_TIMER (TIMER_INTERVAL_PROBE_NORMAL * (jiffies + HZ)) + +#define init_msghdr(h) do { \ + h.msg_name = NULL; \ + h.msg_namelen = 0; \ + h.msg_control = NULL; \ + h.msg_controllen = 0; \ + h.msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \ + } while (0) + + +static void printHost(struct dnbd3_host_t *host, char *msg) +{ + if (host->type == HOST_IP4) { + printk(KERN_INFO "dnbd3: %s %pI4:%d\n", msg, host->addr, host->port); + } else { + printk(KERN_INFO "dnbd3: %s [%pI6]:%d\n", msg, host->addr, host->port); + } +} + +static void printServerList(struct dnbd3_device *dev) +{ + int i; + printHost(&dev->initial_server.host, "initial server is"); + for (i = 0; i < NUMBER_SERVERS; i++) { + if (dev->alt_servers[i].host.addr[0] != 0) { + printHost(&dev->alt_servers[i].host, "alternative server is"); + } + } +} + +int dnbd3_send_request(struct dnbd3_device *dev, struct dnbd3_sock *sock, struct request *req) +{ + dnbd3_request_t dnbd3_request; + dnbd3_reply_t dnbd3_reply; + struct msghdr msg; + struct kvec iov; + struct req_iterator iter; + struct bio_vec bvec_inst; + struct bio_vec *bvec = &bvec_inst; + sigset_t blocked, oldset; + void *kaddr; + int result, count, remaining; + uint16_t rid; + sock->pending = req; + init_msghdr(msg); + + dnbd3_request.magic = dnbd3_packet_magic; + + switch (req_op(req)) { +// case REQ_OP_DISCARD: +// printk(KERN_DEBUG "dnbd3: request operation discard on device %d\n", dev->minor); +// break; +// case REQ_OP_FLUSH: +// printk(KERN_DEBUG "dnbd3: request operation flush on device %d\n", dev->minor); +// break; +// case REQ_OP_WRITE: +// printk(KERN_DEBUG "dnbd3: request operation write on device %d\n", dev->minor); +// break; + case REQ_OP_READ: + printk(KERN_DEBUG "dnbd3: request operation read\n"); + dnbd3_request.cmd = CMD_GET_BLOCK; + dnbd3_request.offset = blk_rq_pos(req) << 9; // *512 + dnbd3_request.size = blk_rq_bytes(req); // bytes left to complete entire request + break; + case REQ_OP_DRV_IN: + printk(KERN_DEBUG "dnbd3: request operation driver in\n"); + dnbd3_request.cmd = dnbd3_priv_to_cmd(req); + dnbd3_request.size = 0; + break; + + default: + return -EIO; + } + + dnbd3_request.handle = (uint64_t)(uintptr_t)req; // Double cast to prevent warning on 32bit + fixup_request(dnbd3_request); + iov.iov_base = &dnbd3_request; + iov.iov_len = sizeof(dnbd3_request); + if (kernel_sendmsg(sock->sock, &msg, &iov, 1, sizeof(dnbd3_request)) != sizeof(dnbd3_request)) { + printk(KERN_ERR "dnbd3: connection to server lost\n"); + result = -EIO; + goto error; + } + + // receive net reply + iov.iov_base = &dnbd3_reply; + iov.iov_len = sizeof(dnbd3_reply); + result = kernel_recvmsg(sock->sock, &msg, &iov, 1, sizeof(dnbd3_reply), msg.msg_flags); + if (!result) { + printk(KERN_ERR "dnbd3: connection to server lost\n"); + result = -EIO; + goto error; + + } + fixup_reply(dnbd3_reply); + + // check error + if (dnbd3_reply.magic != dnbd3_packet_magic) { + printk(KERN_ERR "dnbd3: wrong magic packet\n"); + result = -EIO; + goto error; + } + + if (dnbd3_reply.cmd == 0) { + printk(KERN_ERR "dnbd3: command was 0\n"); + result = -EIO; + goto error; + } + + + switch (dnbd3_reply.cmd) { + case CMD_GET_BLOCK: + rq_for_each_segment(bvec_inst, req, iter) { + siginitsetinv(&blocked, sigmask(SIGKILL)); + sigprocmask(SIG_SETMASK, &blocked, &oldset); + + kaddr = kmap(bvec->bv_page) + bvec->bv_offset; + iov.iov_base = kaddr; + iov.iov_len = bvec->bv_len; + if (kernel_recvmsg(sock->sock, &msg, &iov, 1, bvec->bv_len, msg.msg_flags) != bvec->bv_len) { + kunmap(bvec->bv_page); + sigprocmask(SIG_SETMASK, &oldset, NULL ); + printk(KERN_ERR "dnbd3: could not receive form net to block layer\n"); + goto error; + } + kunmap(bvec->bv_page); + + sigprocmask(SIG_SETMASK, &oldset, NULL ); + } + blk_mq_end_request(req, 0); + break; + case CMD_GET_SERVERS: + printk(KERN_DEBUG "dnbd3: get servers received\n"); + mutex_lock(&dev->device_lock); + if (!dev->use_server_provided_alts) { + remaining = dnbd3_reply.size; + goto consume_payload; + } + dev->new_servers_num = 0; + count = MIN(NUMBER_SERVERS, dnbd3_reply.size / sizeof(dnbd3_server_entry_t)); + + if (count != 0) { + iov.iov_base = dev->new_servers; + iov.iov_len = count * sizeof(dnbd3_server_entry_t); + if (kernel_recvmsg(sock->sock, &msg, &iov, 1, (count * sizeof(dnbd3_server_entry_t)), msg.msg_flags) != (count * sizeof(dnbd3_server_entry_t))) { + printk(KERN_ERR "dnbd3: failed to get servers\n"); + mutex_unlock(&dev->device_lock); + goto error; + } + dev->new_servers_num = count; + } + // If there were more servers than accepted, remove the remaining data from the socket buffer + remaining = dnbd3_reply.size - (count * sizeof(dnbd3_server_entry_t)); +consume_payload: + while (remaining > 0) { + count = MIN(sizeof(dnbd3_reply), remaining); // Abuse the reply struct as the receive buffer + iov.iov_base = &dnbd3_reply; + iov.iov_len = count; + result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); + if (result <= 0) { + printk(KERN_ERR "dnbd3: failed to receive payload from get servers\n"); + mutex_unlock(&dev->device_lock); + goto error; + } + result = 0; + } + mutex_unlock(&dev->device_lock); + break; + case CMD_LATEST_RID: + if (dnbd3_reply.size != 2) { + printk(KERN_ERR "dnbd3: failed to get latest rid, wrong size\n"); + goto error; + } + printk(KERN_DEBUG "dnbd3: latest rid received\n"); + iov.iov_base = &rid; + iov.iov_len = sizeof(rid); + if (kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags) <= 0) { + printk(KERN_ERR "dnbd3: failed to get latest rid\n"); + goto error; + } + rid = net_order_16(rid); + printk("Latest rid of %s is %d (currently using %d)\n", dev->imgname, (int)rid, (int)dev->rid); + dev->update_available = (rid > dev->rid ? 1 : 0); + break; + case CMD_KEEPALIVE: + if (dnbd3_reply.size != 0) { + printk(KERN_ERR "dnbd3: got keep alive packet with payload\n"); + goto error; + } + printk(KERN_DEBUG "dnbd3: keep alive received\n"); + break; + + default: + printk("ERROR: Unknown command (Receive)\n"); + break; + + } + sock->pending = NULL; +error: + return result; +} + + +void dnbd3_keepalive(struct timer_list *arg) +{ + struct dnbd3_sock *sock = container_of(arg, struct dnbd3_sock, keepalive_timer); + printk(KERN_DEBUG "dnbd3: schedule keepalive\n"); +// schedule_work(&sock->keepalive); + sock->keepalive_timer.expires = KEEPALIVE_TIMER; + add_timer(&sock->keepalive_timer); +} + +static void keepalive(struct work_struct *work) +{ + struct dnbd3_sock *sock = container_of(work, struct dnbd3_sock, keepalive); + struct request *req; + mutex_lock(&sock->lock); + req = kmalloc(sizeof(struct request), GFP_ATOMIC ); + // send keepalive + if (req) { + dnbd3_cmd_to_priv(req, CMD_KEEPALIVE); + dnbd3_send_request(NULL, sock, req); // we do not need the device for keepalive + kfree(req); + } else { + printk(KERN_WARNING "dnbd3: could not create keepalive request\n"); + } + ++sock->heartbeat_count; + mutex_unlock(&sock->lock); +} + +void dnbd3_discovery(struct timer_list *arg) +{ + struct dnbd3_device *dev = container_of(arg, struct dnbd3_device, discovery_timer); + printk(KERN_DEBUG "dnbd3: schedule discovery\n"); +// schedule_work(&dev->discovery); + dev->discovery_timer.expires = DISCOVERY_TIMER; + add_timer(&dev->discovery_timer); +} + +static void discovery(struct work_struct *work) +{ + struct dnbd3_device *dev = container_of(work, struct dnbd3_device, discovery); + dnbd3_sock *sock = &dev->socks[0]; // we use the first sock for discovery + struct request *req; + mutex_lock(&sock->lock); + req = kmalloc(sizeof(struct request), GFP_ATOMIC ); + // send keepalive + if (req) { + dnbd3_cmd_to_priv(req, CMD_GET_SERVERS); + dnbd3_send_request(NULL, sock, req); // we do not need the device for keepalive + kfree(req); + } else { + printk(KERN_WARNING "dnbd3: could not create get servers request\n"); + } + mutex_unlock(&sock->lock); +} + + +static int dnbd3_socket_connect(dnbd3_device *dev, dnbd3_sock *sock) +{ + int result = -EIO; + struct request *req1 = NULL; + struct timeval timeout; + struct dnbd3_server *server = sock->server; + dnbd3_request_t dnbd3_request; + dnbd3_reply_t dnbd3_reply; + struct msghdr msg; + struct kvec iov[2]; + uint16_t rid; + uint64_t reported_size; + char *name; + int mlen; + serialized_buffer_t payload_buffer; + + printk(KERN_DEBUG "dnbd3: socket connect device %i\n", dev->minor); + + mutex_init(&sock->lock); + mutex_lock(&sock->lock); + if (sock->pending) { + printk(KERN_DEBUG "dnbd3: socket still in request\n"); + while (sock->pending) + schedule(); + } + if (server->host.port == 0 || server->host.type == 0) { + printk(KERN_ERR "dnbd3: host or port not set\n"); + goto error; + } + if (sock->sock) { + printk(KERN_WARNING "dnbd3: socket already connected\n"); + goto error; + } + + timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA; + timeout.tv_usec = 0; + + req1 = kmalloc(sizeof(*req1), GFP_ATOMIC ); + if (!req1) { + printk(KERN_ERR "dnbd3: kmalloc failed\n"); + goto error; + } + + init_msghdr(msg); + + if (dnbd3_sock_create(server->host.type, SOCK_STREAM, IPPROTO_TCP, &sock->sock) < 0) { + printk(KERN_ERR "dnbd3: could not create socket\n"); + goto error; + } + + kernel_setsockopt(sock->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout, sizeof(timeout)); + kernel_setsockopt(sock->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout, sizeof(timeout)); + sock->sock->sk->sk_allocation = GFP_NOIO; + if (server->host.type == HOST_IP4) { + struct sockaddr_in sin; + memset(&sin, 0, sizeof(sin)); + sin.sin_family = AF_INET; + memcpy(&(sin.sin_addr), server->host.addr, 4); + sin.sin_port = server->host.port; + if (kernel_connect(sock->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0) { + printk(KERN_ERR "dnbd3: connection to host failed (ipv4)\n"); + goto error; + } + } else { + struct sockaddr_in6 sin; + memset(&sin, 0, sizeof(sin)); + sin.sin6_family = AF_INET6; + memcpy(&(sin.sin6_addr), server->host.addr, 16); + sin.sin6_port = server->host.port; + if (kernel_connect(sock->sock, (struct sockaddr *)&sin, sizeof(sin), 0) != 0){ + printk(KERN_ERR "dnbd3: connection to host failed (ipv6)\n"); + goto error; + } + } + // Request filesize + dnbd3_request.magic = dnbd3_packet_magic; + dnbd3_request.cmd = CMD_SELECT_IMAGE; + iov[0].iov_base = &dnbd3_request; + iov[0].iov_len = sizeof(dnbd3_request); + serializer_reset_write(&payload_buffer); + serializer_put_uint16(&payload_buffer, PROTOCOL_VERSION); + serializer_put_string(&payload_buffer, dev->imgname); + serializer_put_uint16(&payload_buffer, dev->rid); + serializer_put_uint8(&payload_buffer, 0); // is_server = false + iov[1].iov_base = &payload_buffer; + dnbd3_request.size = iov[1].iov_len = serializer_get_written_length(&payload_buffer); + fixup_request(dnbd3_request); + mlen = sizeof(dnbd3_request) + iov[1].iov_len; + if (kernel_sendmsg(sock->sock, &msg, iov, 2, mlen) != mlen) { + printk(KERN_ERR "dnbd3: could not send CMD_SIZE_REQUEST\n"); + goto error; + } + // receive reply header + iov[0].iov_base = &dnbd3_reply; + iov[0].iov_len = sizeof(dnbd3_reply); + if (kernel_recvmsg(sock->sock, &msg, iov, 1, sizeof(dnbd3_reply), msg.msg_flags) != sizeof(dnbd3_reply)) { + printk(KERN_ERR "dnbd3: received corrupted reply header after CMD_SIZE_REQUEST\n"); + goto error; + } + + // check reply header + fixup_reply(dnbd3_reply); + if (dnbd3_reply.cmd != CMD_SELECT_IMAGE || + dnbd3_reply.size < 3 || + dnbd3_reply.size > MAX_PAYLOAD || + dnbd3_reply.magic != dnbd3_packet_magic) { + printk(KERN_ERR "dnbd3: received invalid reply to CMD_SIZE_REQUEST image does not exist on server\n"); + goto error; + } + + // receive reply payload + iov[0].iov_base = &payload_buffer; + iov[0].iov_len = dnbd3_reply.size; + if (kernel_recvmsg(sock->sock, &msg, iov, 1, dnbd3_reply.size, msg.msg_flags) != dnbd3_reply.size) { + printk(KERN_ERR "dnbd3: could not read CMD_SELECT_IMAGE payload on handshake\n"); + goto error; + } + + // handle/check reply payload + serializer_reset_read(&payload_buffer, dnbd3_reply.size); + server->protocol_version = serializer_get_uint16(&payload_buffer); + if (server->protocol_version < MIN_SUPPORTED_SERVER) { + printk(KERN_ERR "dnbd3: server version is lower than min supported version\n"); + goto error; + } + + name = serializer_get_string(&payload_buffer); + rid = serializer_get_uint16(&payload_buffer); + if (dev->rid != rid && strcmp(name, dev->imgname) != 0) { + printk(KERN_ERR "dnbd3: server offers image '%s', requested '%s'\n", name, dev->imgname); + goto error; + } + + reported_size = serializer_get_uint64(&payload_buffer); + if (dev->reported_size == NULL) { + if (reported_size < 4096) { + printk(KERN_ERR "dnbd3: reported size by server is < 4096\n"); + goto error; + } + dev->reported_size = reported_size; + set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte blocks */ + } else if (dev->reported_size != reported_size) { + printk(KERN_ERR "dnbd3: reported size by server is %llu but should be %llu\n", reported_size, dev->reported_size); + } + + printk(KERN_DEBUG "dnbd3: connected to image %s, filesize %llu\n", dev->imgname, dev->reported_size); + +// TODO add heartbeat + // add heartbeat timer and scheduler for the command + INIT_WORK(&sock->keepalive, keepalive); + sock->heartbeat_count = 0; + timer_setup(&sock->keepalive_timer, dnbd3_keepalive, 0); + sock->keepalive_timer.expires = KEEPALIVE_TIMER; + add_timer(&sock->keepalive_timer); + + mutex_unlock(&sock->lock); + + return 0; +error: + if (sock->sock) { + sock_release(sock->sock); + sock->sock = NULL; + } + if (req1) { + kfree(req1); + } + mutex_unlock(&sock->lock); + return result; +} + +static int dnbd3_socket_disconnect(dnbd3_device *dev, dnbd3_sock *sock) +{ + printk(KERN_DEBUG "dnbd3: socket disconnect device %i\n", dev->minor); + mutex_lock(&sock->lock); + + // clear heartbeat timer + del_timer_sync(&sock->keepalive_timer); +// destroy_workqueue(&sock->keepalive); + + if (sock->sock) { + kernel_sock_shutdown(sock->sock, SHUT_RDWR); + } + + // clear socket + if (sock->sock) { + sock_release(sock->sock); + sock->sock = NULL; + } + + mutex_unlock(&sock->lock); + mutex_destroy(&sock->lock); + return 0; +} + +int dnbd3_net_disconnect(struct dnbd3_device *dev) +{ + int i; + int result; + del_timer_sync(&dev->discovery_timer); +// destroy_workqueue(&dev->discovery); + for (i = 0; i < NUMBER_CONNECTIONS; i++) { + if (dev->socks[i].sock) { + if (dnbd3_socket_disconnect(dev, &dev->socks[i])) { + result = -EIO; + } + } + } + return result; +} + + +int dnbd3_net_connect(struct dnbd3_device *dev) { + // TODO decide which socket to connect + int result; + dev->socks[0].server = &dev->initial_server; + if (dnbd3_socket_connect(dev, &dev->socks[0]) == 0) { + printServerList(dev); + + INIT_WORK(&dev->discovery, discovery); + timer_setup(&dev->discovery_timer, dnbd3_discovery, 0); + dev->discovery_timer.expires = DISCOVERY_TIMER; + add_timer(&dev->discovery_timer); + + result = 0; + } else { + printk(KERN_ERR "dnbd3: failed to connect to initial server\n"); + result = -ENOENT; + dev->imgname = NULL; + dev->socks[0].server = NULL; + } + return result; +} + + + -- cgit v1.2.3-55-g7522