/*
* This file is part of the Distributed Network Block Device 3
*
* Copyright(c) 2019 Frederic Robra <frederic@robra.org>
* Parts copyright 2011-2012 Johann Latocha <johann@latocha.de>
*
* This file may be licensed under the terms of the
* GNU General Public License Version 2 (the ``GPL'').
*
* Software distributed under the License is distributed
* on an ``AS IS'' basis, WITHOUT WARRANTY OF ANY KIND, either
* express or implied. See the GPL for the specific language
* governing rights and limitations.
*
* You should have received a copy of the GPL along with this
* program. If not, go to http://www.gnu.org/licenses/gpl.html
* or write to the Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*
*/
#include <linux/ktime.h>
#include <linux/sort.h>
#include <linux/wait.h>
#include <net/sock.h>
#include "net.h"
#include "utils.h"
#include "clientconfig.h"
#include "mq.h"
/*
 * Driver-internal request operations. A "special" request carries the dnbd3
 * command in the cmd_flags bits above REQ_FLAG_BITS, a "connect" request is
 * sent as CMD_SELECT_IMAGE by dnbd3_send_request().
 */
#define DNBD3_REQ_OP_SPECIAL REQ_OP_DRV_IN
#define DNBD3_REQ_OP_CONNECT REQ_OP_DRV_OUT
/* store the dnbd3 command @cmd in cmd_flags of @req as a special request */
#define dnbd3_cmd_to_op_special(req, cmd) \
(req)->cmd_flags = DNBD3_REQ_OP_SPECIAL | ((cmd) << REQ_FLAG_BITS)
/* extract the dnbd3 command of a special request from its cmd_flags */
#define dnbd3_op_special_to_cmd(req) \
((req)->cmd_flags >> REQ_FLAG_BITS)
/* turn @req into a connect request (CMD_SELECT_IMAGE) */
#define dnbd3_connect_to_req(req) \
(req)->cmd_flags = DNBD3_REQ_OP_CONNECT \
| ((CMD_SELECT_IMAGE) << REQ_FLAG_BITS)
/* turn @req into a read of RTT_BLOCK_SIZE bytes at sector 0 (the RTT probe) */
#define dnbd3_test_block_to_req(req) \
do { \
(req)->cmd_flags = REQ_OP_READ; \
(req)->__data_len = RTT_BLOCK_SIZE; \
(req)->__sector = 0; \
} while (0)
/*
 * create a kernel socket in init_net; host type HOST_IP4 selects AF_INET,
 * any other host type AF_INET6
 */
#define dnbd3_sock_create(af,type,proto,sock) \
sock_create_kern(&init_net, (af) == HOST_IP4 ? AF_INET : AF_INET6, \
type, proto, sock)
/* request timeout in jiffies (SOCKET_TIMEOUT_CLIENT_DATA seconds) */
#define REQUEST_TIMEOUT \
(HZ * SOCKET_TIMEOUT_CLIENT_DATA)
/*
 * prepare a msghdr for kernel_sendmsg()/kernel_recvmsg(): no address, no
 * control data, MSG_WAITALL to wait for the full length (bounded by the
 * socket send/receive timeouts) and MSG_NOSIGNAL to not raise SIGPIPE
 */
#define dnbd3_init_msghdr(h) \
do { \
(h).msg_name = NULL; \
(h).msg_namelen = 0; \
(h).msg_control = NULL; \
(h).msg_controllen = 0; \
(h).msg_flags = MSG_WAITALL | MSG_NOSIGNAL; \
} while (0)
/* forward declarations: functions below use these before their definitions */
static int __dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server * server);
static int dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server * server);
static int dnbd3_socket_disconnect(struct dnbd3_sock *sock);
/*
* Methods for request and receive commands
*/
/**
 * dnbd3_to_handle - pack a tag and a cookie into one 64 bit handle
 * @tag: stored in the upper 32 bits of the handle
 * @cookie: stored in the lower 32 bits of the handle
 *
 * Return: the packed handle; dnbd3_tag_from_handle() and
 * dnbd3_cookie_from_handle() recover the two halves.
 */
static inline uint64_t dnbd3_to_handle(uint32_t tag, uint32_t cookie)
{
	const uint64_t upper = (uint64_t)tag << 32;
	const uint64_t lower = (uint64_t)cookie;

	return upper | lower;
}
/**
 * dnbd3_tag_from_handle - extract the tag from a handle
 * @handle: a handle built by dnbd3_to_handle()
 *
 * Return: the upper 32 bits of @handle
 */
static inline uint32_t dnbd3_tag_from_handle(uint64_t handle)
{
	const uint64_t upper = handle >> 32;

	return (uint32_t)upper;
}
/**
 * dnbd3_cookie_from_handle - extract the cookie from a handle
 * @handle: a handle built by dnbd3_to_handle()
 *
 * Return: the lower 32 bits of @handle
 */
static inline uint32_t dnbd3_cookie_from_handle(uint64_t handle)
{
	const uint64_t lower = handle & 0xFFFFFFFFULL;

	return (uint32_t)lower;
}
/**
 * dnbd3_send_request - send a request header (and optional payload)
 * @sock: the socket where the request is sent
 * @req: the request to send; its operation selects the dnbd3 command
 * @cmd: optional - the dnbd3_cmd from mq; when given, its cookie is updated
 *       and the blk-mq unique tag is encoded in the upper 32 bits of the
 *       handle
 *
 * The tx_lock of the socket must be held.
 *
 * Return: the number of bytes sent on success, -EIO for an unsupported
 * operation or an incomplete send, or the negative error of
 * kernel_sendmsg(). On a failed send the socket is marked as panicked, the
 * server's failure counter is incremented and @cmd (if given) is requeued.
 */
int dnbd3_send_request(struct dnbd3_sock *sock, struct request *req,
		struct dnbd3_cmd *cmd)
{
	dnbd3_request_t request;
	struct msghdr msg;
	struct kvec iov[2];
	size_t iov_num = 1;
	size_t lng;
	int result;
	uint32_t tag;
	uint64_t handle;
	serialized_buffer_t payload_buffer;

	/*
	 * zero the whole header: not every operation sets every field (offset
	 * is only set for reads) and uninitialized stack memory must not be
	 * sent to the server
	 */
	memset(&request, 0, sizeof(request));
	sock->pending = req;
	dnbd3_init_msghdr(msg);
	request.magic = dnbd3_packet_magic;
	switch (req_op(req)) {
	case REQ_OP_READ:
		debug_sock(sock, "request operation read");
		request.cmd = CMD_GET_BLOCK;
		request.offset = blk_rq_pos(req) << 9; /* sectors of 512 bytes */
		request.size = blk_rq_bytes(req);
		break;
	case DNBD3_REQ_OP_SPECIAL:
		debug_sock(sock, "request operation special");
		request.cmd = dnbd3_op_special_to_cmd(req);
		request.size = 0;
		break;
	case DNBD3_REQ_OP_CONNECT:
		debug_sock(sock, "request operation connect to %s",
				sock->device->imgname);
		request.cmd = CMD_SELECT_IMAGE;
		serializer_reset_write(&payload_buffer);
		serializer_put_uint16(&payload_buffer, PROTOCOL_VERSION);
		serializer_put_string(&payload_buffer, sock->device->imgname);
		serializer_put_uint16(&payload_buffer, sock->device->rid);
		serializer_put_uint8(&payload_buffer, 0); /* is_server = false */
		iov[1].iov_base = &payload_buffer;
		request.size = serializer_get_written_length(&payload_buffer);
		iov[1].iov_len = request.size;
		iov_num = 2;
		break;
	default:
		/* nothing is sent, don't leave a pointer to the request behind */
		sock->pending = NULL;
		return -EIO;
	}
	sock->cookie++;
	if (cmd != NULL) {
		cmd->cookie = sock->cookie;
		tag = blk_mq_unique_tag(req);
		handle = dnbd3_to_handle(tag, sock->cookie);
	} else {
		handle = sock->cookie;
	}
	memcpy(&request.handle, &handle, sizeof(handle));
	fixup_request(request);
	iov[0].iov_base = &request;
	iov[0].iov_len = sizeof(request);
	lng = iov[0].iov_len + (iov_num == 2 ? iov[1].iov_len : 0);
	result = kernel_sendmsg(sock->sock, &msg, iov, iov_num, lng);
	if (result < 0 || (size_t)result != lng) {
		error_sock(sock, "connection to server lost");
		if (cmd)
			dnbd3_requeue_cmd(cmd);
		sock->panic = true;
		sock->server->failures++;
		/* a partial send is a failure as well, never report it as success */
		if (result >= 0)
			result = -EIO;
	}
	sock->pending = NULL;
	return result;
}
/**
 * dnbd3_send_request_cmd - send a driver-internal dnbd3 command
 * @sock: the socket where the request is sent
 * @dnbd3_cmd: the command to send: CMD_KEEPALIVE, CMD_GET_SERVERS,
 *             CMD_SELECT_IMAGE or CMD_GET_BLOCK (the RTT test block)
 *
 * A temporary struct request is allocated only to describe the command to
 * dnbd3_send_request(); it is freed before returning. Takes sock->tx_lock.
 *
 * Return: the result of dnbd3_send_request() (bytes sent on success),
 * -EIO if the allocation failed or -EINVAL for an unsupported command.
 */
static int dnbd3_send_request_cmd(struct dnbd3_sock *sock, uint16_t dnbd3_cmd)
{
	int result;
	/*
	 * zeroed: the helper macros below only set the fields
	 * dnbd3_send_request() reads, everything else must not be garbage
	 */
	struct request *req = kzalloc(sizeof(*req), GFP_KERNEL);

	if (!req) {
		error_sock(sock, "kmalloc failed");
		return -EIO;
	}
	switch (dnbd3_cmd) {
	case CMD_KEEPALIVE:
	case CMD_GET_SERVERS:
		dnbd3_cmd_to_op_special(req, dnbd3_cmd);
		break;
	case CMD_SELECT_IMAGE:
		dnbd3_connect_to_req(req);
		break;
	case CMD_GET_BLOCK:
		dnbd3_test_block_to_req(req);
		break;
	default:
		warn_sock(sock, "unsupported command %d", dnbd3_cmd);
		result = -EINVAL;
		goto out;
	}
	/* dnbd3_send_request() sets and clears sock->pending itself */
	mutex_lock(&sock->tx_lock);
	result = dnbd3_send_request(sock, req, NULL);
	mutex_unlock(&sock->tx_lock);
out:
	kfree(req);
	return result;
}
/**
 * dnbd3_receive_cmd - receive a reply header
 * @sock: the socket where the reply is received
 * @reply: filled with the reply header of the server
 *
 * Call this directly after dnbd3_send_request() / dnbd3_send_request_cmd(),
 * or in the receive loop; the payload (reply->size bytes) is left in the
 * socket for the command specific receive function.
 *
 * Return: the header length on success, the (<= 0) result of
 * kernel_recvmsg() if it failed, or -EIO for an incomplete header, a wrong
 * magic or a command of 0.
 */
static int dnbd3_receive_cmd(struct dnbd3_sock *sock, dnbd3_reply_t *reply)
{
	int result;
	struct msghdr msg;
	struct kvec iov;

	dnbd3_init_msghdr(msg);
	iov.iov_base = reply;
	iov.iov_len = sizeof(*reply);
	result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len,
			msg.msg_flags);
	if (result <= 0)
		return result;
	if (result != sizeof(*reply)) {
		error_sock(sock, "receive cmd incomplete header, got %d bytes",
				result);
		return -EIO;
	}
	/* fix up the received header in place (it must name @reply itself) */
	fixup_reply(*reply);
	if (reply->magic != dnbd3_packet_magic) {
		error_sock(sock, "receive cmd wrong magic packet");
		return -EIO;
	}
	if (reply->cmd == 0) {
		error_sock(sock, "receive command was 0");
		return -EIO;
	}
	return result;
}
/**
* dnbd3_receive_cmd_get_block_mq - receive a block for mq
* @sock: the socket where the request is received
* @reply: the reply initialized by dnbd3_receive_cmd
*
* this method should be called directly after the dnbd3_receive_cmd method
*
* this method copies the data to user space according to the request which is
* encoded in the handle by the send request method and decoded here.
*/
static int dnbd3_receive_cmd_get_block_mq(struct dnbd3_sock *sock,
dnbd3_reply_t *reply)
{
struct dnbd3_cmd *cmd;
struct msghdr msg;
struct request *req = NULL;
struct kvec iov;
struct req_iterator iter;
struct bio_vec bvec_inst;
struct dnbd3_device *dev = sock->device;
struct bio_vec *bvec = &bvec_inst;
sigset_t blocked, oldset;
void *kaddr;
uint32_t tag, cookie;
uint16_t hwq;
uint32_t remaining = reply->size;
int result = 0;
uint64_t handle;
dnbd3_init_msghdr(msg);
memcpy(&handle, &reply->handle, sizeof(handle));
cookie = dnbd3_cookie_from_handle(handle);
tag = dnbd3_tag_from_handle(handle);
hwq = blk_mq_unique_tag_to_hwq(tag);
if (hwq < dev->tag_set.nr_hw_queues) {
req = blk_mq_tag_to_rq(dev->tag_set.tags[hwq],
blk_mq_unique_tag_to_tag(tag));
}
if (!req || !blk_mq_request_started(req)) {
error_sock(sock, "unexpected reply (%d) %p", tag, req);
if (req) {
debug_sock(sock, "requeue request");
dnbd3_requeue_cmd(blk_mq_rq_to_pdu(req));
}
// return -EIO;
goto clear_socket;
}
cmd = blk_mq_rq_to_pdu(req);
mutex_lock(&cmd->lock);
if (cmd->cookie != cookie) {
error_sock(sock, "double reply on req %p, cookie %u, handle cookie %u",
req, cmd->cookie, cookie);
mutex_unlock(&cmd->lock);
// return -EIO;
goto clear_socket;
}
rq_for_each_segment(bvec_inst, req, iter) {
siginitsetinv(&blocked, sigmask(SIGKILL));
sigprocmask(SIG_SETMASK, &blocked, &oldset);
kaddr = kmap(bvec->bv_page) + bvec->bv_offset;
iov.iov_base = kaddr;
iov.iov_len = bvec->bv_len;
result = kernel_recvmsg(sock->sock, &msg, &iov, 1, bvec->bv_len,
msg.msg_flags);
remaining -= result;
if (result != bvec->bv_len) {
kunmap(bvec->bv_page);
sigprocmask(SIG_SETMASK, &oldset, NULL );
error_sock(sock, "could not receive from net to block layer");
dnbd3_requeue_cmd(cmd);
mutex_unlock(&cmd->lock);
if (result >= 0) {
goto clear_socket;
} else {
return result;
}
}
kunmap(bvec->bv_page);
sigprocmask(SIG_SETMASK, &oldset, NULL );
}
mutex_unlock(&cmd->lock);
dnbd3_end_cmd(cmd, 0);
return result;
clear_socket:
warn_sock(sock, "caught an error while receiving block, clearing buffer");
char *buf = kmalloc(RTT_BLOCK_SIZE, GFP_KERNEL);
if (!buf) {
error_sock(sock, "kmalloc failed");
return -EIO;
}
iov.iov_base = buf;
iov.iov_len = RTT_BLOCK_SIZE;
while (remaining > 0) {
result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len,
msg.msg_flags);
if (result <= 0) {
goto error;
}
remaining -= result;
}
debug_sock(sock, "cleared buffer %d bytes, reply size is %d", result,
reply->size);
error:
if (buf) {
kfree(buf);
}
return -EIO;
}
/**
 * dnbd3_receive_cmd_get_block_test - receive a test block
 * @sock: the socket where the reply is received
 * @reply: the reply header filled by dnbd3_receive_cmd()
 *
 * Call this directly after dnbd3_receive_cmd(). The reply->size bytes of
 * payload are received into a temporary buffer and thrown away.
 *
 * Return: the number of bytes received on success, -ENOMEM if no buffer
 * could be allocated, -EIO for an incomplete receive, or the (<= 0) result
 * of kernel_recvmsg().
 */
static int dnbd3_receive_cmd_get_block_test(struct dnbd3_sock *sock,
		dnbd3_reply_t *reply)
{
	struct msghdr msg;
	struct kvec iov;
	int result;
	char *buf = kmalloc(reply->size, GFP_KERNEL);

	if (!buf) {
		error_sock(sock, "kmalloc failed");
		return -ENOMEM;
	}
	dnbd3_init_msghdr(msg);
	iov.iov_base = buf;
	iov.iov_len = reply->size;
	result = kernel_recvmsg(sock->sock, &msg, &iov, 1, reply->size,
			msg.msg_flags);
	if (result < 0 || (uint32_t)result != reply->size) {
		error_sock(sock, "receive test block failed");
		/* a partial block must not be reported as success */
		if (result > 0)
			result = -EIO;
	}
	kfree(buf);
	return result;
}
/**
 * dnbd3_receive_cmd_get_servers - receive new servers
 * @sock: the socket where the reply is received
 * @reply: the reply header filled by dnbd3_receive_cmd(); it is reused as a
 *         scratch buffer while draining surplus payload
 *
 * Call this directly after dnbd3_receive_cmd(). If the device accepts server
 * provided alternatives, up to NUMBER_SERVERS entries are copied to
 * dnbd3_device.new_servers and dnbd3_device.new_servers_num is set
 * accordingly. Surplus payload is always drained from the socket.
 *
 * All blocking receives happen without dnbd3_device.device_lock: workers
 * holding that lock tear sockets down and wait for this receive worker with
 * cancel_work_sync(), so blocking on the lock here could deadlock.
 *
 * Return: > 0 on success, -EIO for an incomplete server list, or the
 * (<= 0) result of a failing kernel_recvmsg().
 */
static int dnbd3_receive_cmd_get_servers(struct dnbd3_sock *sock,
		dnbd3_reply_t *reply)
{
	struct msghdr msg;
	struct kvec iov;
	struct dnbd3_device *dev = sock->device;
	dnbd3_server_entry_t *servers = NULL;
	bool accept;
	int result = 1;
	int count = 0;
	int remaining;
	int chunk;

	dnbd3_init_msghdr(msg);
	debug_sock(sock, "get servers received");
	mutex_lock(&dev->device_lock);
	accept = dev->use_server_provided_alts;
	mutex_unlock(&dev->device_lock);

	if (accept)
		count = MIN(NUMBER_SERVERS, reply->size / sizeof(dnbd3_server_entry_t));
	if (count != 0) {
		servers = kmalloc_array(count, sizeof(*servers), GFP_KERNEL);
		if (!servers) {
			error_sock(sock, "kmalloc failed");
			/* still drain the payload below to keep the stream in sync */
			count = 0;
		}
	}
	if (count != 0) {
		iov.iov_base = servers;
		iov.iov_len = count * sizeof(*servers);
		result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len,
				msg.msg_flags);
		if (result <= 0) {
			error_sock(sock, "failed to receive get servers %d",
					result);
			goto out;
		} else if (result != (count * sizeof(dnbd3_server_entry_t))) {
			error_sock(sock, "failed to get servers");
			result = -EIO;
			goto out;
		}
	}
	/*
	 * remove everything that was not accepted (more servers than fit, a
	 * trailing partial entry, or the whole list if alternatives are
	 * disabled); abuse the reply struct as the receive buffer
	 */
	remaining = reply->size - (count * sizeof(dnbd3_server_entry_t));
	while (remaining > 0) {
		chunk = remaining < (int)sizeof(dnbd3_reply_t) ?
				remaining : (int)sizeof(dnbd3_reply_t);
		iov.iov_base = reply;
		iov.iov_len = chunk;
		result = kernel_recvmsg(sock->sock, &msg, &iov, 1, chunk,
				msg.msg_flags);
		if (result <= 0) {
			error_sock(sock, "failed to receive payload from get servers");
			goto out;
		}
		remaining -= result;
	}
	if (accept) {
		/* publish the complete list; the lock is not held while receiving */
		mutex_lock(&dev->device_lock);
		if (count != 0)
			memcpy(dev->new_servers, servers, count * sizeof(*servers));
		dev->new_servers_num = count;
		mutex_unlock(&dev->device_lock);
	}
out:
	kfree(servers);
	return result;
}
/**
 * dnbd3_receive_cmd_latest_rid - receive latest rid
 * @sock: the socket where the reply is received
 * @reply: the reply header filled by dnbd3_receive_cmd()
 *
 * Call this directly after dnbd3_receive_cmd(). The payload must be exactly
 * one 16 bit rid; dnbd3_device.update_available is set if it is newer than
 * the rid in use.
 *
 * Return: the number of bytes received on success, -EIO for a wrong payload
 * size or an incomplete receive, or the (<= 0) result of kernel_recvmsg().
 */
static int dnbd3_receive_cmd_latest_rid(struct dnbd3_sock *sock,
		dnbd3_reply_t *reply)
{
	struct kvec iov;
	uint16_t rid;
	int result;
	struct msghdr msg;
	struct dnbd3_device *dev = sock->device;

	dnbd3_init_msghdr(msg);
	debug_sock(sock, "latest rid received");
	if (reply->size != sizeof(rid)) {
		error_sock(sock, "failed to get latest rid, wrong size");
		return -EIO;
	}
	iov.iov_base = &rid;
	iov.iov_len = sizeof(rid);
	result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len,
			msg.msg_flags);
	if (result <= 0) {
		error_sock(sock, "failed to receive latest rid");
		return result;
	}
	if (result != sizeof(rid)) {
		/* a single byte is not a rid */
		error_sock(sock, "failed to receive latest rid");
		return -EIO;
	}
	rid = net_order_16(rid);
	debug_sock(sock, "latest rid of %s is %d (currently using %d)",
			dev->imgname, (int)rid, (int)dev->rid);
	dev->update_available = rid > dev->rid;
	return result;
}
/**
 * dnbd3_receive_cmd_select_image - receive the reply to CMD_SELECT_IMAGE
 * @sock: the socket where the reply is received
 * @reply: the reply header filled by dnbd3_receive_cmd()
 *
 * Call this directly after dnbd3_receive_cmd(). The payload carries the
 * server's protocol version, the image name, the rid and the image size.
 * On the first connection the image size is stored in the device and the
 * disk capacity is set; on further connections the size must match.
 *
 * Return: the payload length on success, -EIO if the payload is too large,
 * incomplete or does not match the requested image, or the (<= 0) result of
 * kernel_recvmsg().
 */
static int dnbd3_receive_cmd_select_image(struct dnbd3_sock *sock,
		dnbd3_reply_t *reply)
{
	struct kvec iov;
	uint16_t rid;
	char *name;
	int result;
	struct msghdr msg;
	serialized_buffer_t payload_buffer;
	uint64_t reported_size;
	struct dnbd3_device *dev = sock->device;

	dnbd3_init_msghdr(msg);
	debug_sock(sock, "select image received");
	/*
	 * the payload is received into a fixed size stack buffer, a larger size
	 * announced by the server would overflow it
	 */
	if (reply->size > sizeof(payload_buffer.buffer)) {
		error_sock(sock, "select image payload too large (%u bytes)",
				reply->size);
		return -EIO;
	}
	iov.iov_base = &payload_buffer;
	iov.iov_len = reply->size;
	result = kernel_recvmsg(sock->sock, &msg, &iov, 1, iov.iov_len,
			msg.msg_flags);
	if (result <= 0) {
		error_sock(sock, "failed to receive select image %d", result);
		return result;
	} else if (result != reply->size) {
		error_sock(sock, "could not read CMD_SELECT_IMAGE payload on handshake, size is %d and should be %d",
				result, reply->size);
		return -EIO;
	}
	/* handle/check reply payload */
	serializer_reset_read(&payload_buffer, reply->size);
	sock->server->protocol_version = serializer_get_uint16(&payload_buffer);
	if (sock->server->protocol_version < MIN_SUPPORTED_SERVER) {
		error_sock(sock, "server version is lower than min supported version");
		return -EIO;
	}
	name = serializer_get_string(&payload_buffer);
	if (name == NULL) {
		error_sock(sock, "server sent a malformed image name");
		return -EIO;
	}
	rid = serializer_get_uint16(&payload_buffer);
	/*
	 * the name must always match; the rid is only enforced when a specific
	 * revision was requested
	 * NOTE(review): assumes a requested rid of 0 means "latest revision"
	 */
	if (strcmp(name, dev->imgname) != 0 ||
			(dev->rid != 0 && dev->rid != rid)) {
		error_sock(sock, "server offers image '%s', requested '%s'",
				name, dev->imgname);
		return -EIO;
	}
	reported_size = serializer_get_uint64(&payload_buffer);
	if (!dev->reported_size) {
		if (reported_size < 4096) {
			error_sock(sock, "reported size by server is < 4096");
			return -EIO;
		}
		dev->reported_size = reported_size;
		set_capacity(dev->disk, dev->reported_size >> 9); /* 512 Byte */
	} else if (dev->reported_size != reported_size) {
		error_sock(sock, "reported size by server is %llu but should be %llu",
				reported_size, dev->reported_size);
		return -EIO;
	}
	return result;
}
/*
* Timer and workers
*/
/**
 * dnbd3_timer - periodic device timer, re-armed every HZ jiffies (1 second)
 * @arg: the timer_list embedded in the dnbd3_device
 *
 * Every tick queues the panic worker. While the blk-mq queue is not busy it
 * also queues the keepalive worker of every alive socket when timer_count is
 * a multiple of TIMER_INTERVAL_KEEPALIVE_PACKET, queues the discovery worker
 * when timer_count % TIMER_INTERVAL_PROBE_NORMAL == 4, and then advances
 * timer_count. A busy queue therefore pauses the keepalive/discovery count.
 */
static void dnbd3_timer(struct timer_list *arg)
{
	struct dnbd3_device *dev = container_of(arg, struct dnbd3_device, timer);
	bool keepalive_due;
	bool discovery_due;
	int n;

	queue_work(dnbd3_wq, &dev->panic_worker);

	if (!dnbd3_is_mq_busy(dev)) {
		keepalive_due =
			dev->timer_count % TIMER_INTERVAL_KEEPALIVE_PACKET == 0;
		/* discovery starts 4 ticks into each probe interval */
		discovery_due =
			dev->timer_count % TIMER_INTERVAL_PROBE_NORMAL == 4;

		for (n = 0; keepalive_due && n < NUMBER_CONNECTIONS; n++) {
			if (dnbd3_is_sock_alive(dev->socks[n]))
				queue_work(dnbd3_wq,
						&dev->socks[n].keepalive_worker);
		}
		if (discovery_due)
			queue_work(dnbd3_wq, &dev->discovery_worker);
		dev->timer_count++;
	}

	dev->timer.expires = jiffies + HZ;
	add_timer(&dev->timer);
}
/**
* dnbd3_receive_worker - receives data from a socket
* @work: the work used to get the dndb3_sock
*
* receives data until the socket is closed (returns 0)
*/
static void dnbd3_receive_worker(struct work_struct *work)
{
struct dnbd3_sock *sock;
dnbd3_reply_t reply;
int result;
sock = container_of(work, struct dnbd3_sock, receive_worker);
debug_sock(sock, "receive worker is starting");
while(1) { // loop until socket returns 0
result = dnbd3_receive_cmd(sock, &reply);
if (result == -EAGAIN) {
continue;
} else if (result <= 0) {
error_sock(sock, "connection to server lost %d", result);
goto error;
}
switch (reply.cmd) {
case CMD_GET_BLOCK:
result = dnbd3_receive_cmd_get_block_mq(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd get block mq failed %d",
result);
goto error;
}
continue;
case CMD_GET_SERVERS:
result = dnbd3_receive_cmd_get_servers(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd get servers failed %d",
result);
goto error;
}
break;
case CMD_LATEST_RID:
result = dnbd3_receive_cmd_latest_rid(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd latest rid failed %d",
result);
goto error;
}
break;
case CMD_KEEPALIVE:
if (reply.size != 0) {
error_sock(sock, "got keep alive packet with payload");
goto error;
}
debug_sock(sock, "keep alive received");
break;
case CMD_SELECT_IMAGE:
result = dnbd3_receive_cmd_select_image(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd select image failed %d",
result);
goto error;
}
break;
default:
warn_sock(sock, "unknown command received");
break;
}
error:
if (result == 0) {
info_sock(sock, "result is 0, socket seems to be down");
sock->panic = true;
break;
} else if (result < 0) {
/* discovery takes care of to many failures */
sock->server->failures++;
warn_sock(sock, "receive error happened %d, total failures %d",
result, sock->server->failures);
}
debug_sock(sock, "receive completed, waiting for next receive");
}
debug_sock(sock, "receive work queue is stopped");
}
/**
 * dnbd3_keepalive_worker - send a keepalive to the server
 * @work: the keepalive_worker embedded in the dnbd3_sock
 *
 * The result is not checked here; a failed kernel_sendmsg() marks the
 * socket as panicked inside dnbd3_send_request().
 */
static void dnbd3_keepalive_worker(struct work_struct *work)
{
	struct dnbd3_sock *sock = container_of(work, struct dnbd3_sock,
			keepalive_worker);

	debug_sock(sock, "starting keepalive worker");
	(void)dnbd3_send_request_cmd(sock, CMD_KEEPALIVE);
}
/**
* dnbd3_compare_servers - comparator for the server
* @lhs: left hand sign
* @rhs: right hand sign
*/
static int dnbd3_compare_servers(const void *lhs, const void *rhs) {
uint64_t l, r;
struct dnbd3_server *lhs_server = *((struct dnbd3_server **) lhs);
struct dnbd3_server *rhs_server = *((struct dnbd3_server **) rhs);
l = lhs_server->host.type != 0 ? lhs_server->avg_rtt
: RTT_UNREACHABLE + 1;
r = rhs_server->host.type != 0 ? rhs_server->avg_rtt
: RTT_UNREACHABLE + 1;
if (l < r) {
return -1;
} else if (l > r) {
return 1;
}
return 0;
}
/**
 * dnbd3_sort_server - sort the alt servers by their avg rtt
 * @dev: the dnbd3 device
 *
 * Builds an array of pointers to all NUMBER_SERVERS entries of
 * dev->alt_servers, ordered by dnbd3_compare_servers().
 *
 * Return: the sorted array, which has to be freed with kfree(), or NULL if
 * the allocation failed
 */
static struct dnbd3_server **dnbd3_sort_server(struct dnbd3_device *dev)
{
	int i;
	struct dnbd3_server **sorted_servers;

	/* element size is that of a server pointer, not a device pointer */
	sorted_servers = kmalloc_array(NUMBER_SERVERS, sizeof(*sorted_servers),
			GFP_KERNEL);
	if (!sorted_servers)
		return NULL;
	for (i = 0; i < NUMBER_SERVERS; i++)
		sorted_servers[i] = &dev->alt_servers[i];
	sort(sorted_servers, NUMBER_SERVERS, sizeof(*sorted_servers),
			dnbd3_compare_servers, NULL);
	return sorted_servers;
}
/**
 * dnbd3_panic_connect - last resort connect when the connection plan failed
 * @dev: the dnbd3 device
 *
 * 1. if a socket is still alive, its server is used
 * 2. otherwise the configured alt servers are tried in order on socket 0
 *    and the first one that accepts the connection is used
 * 3. every socket that is not assigned to that server is (re)connected to it
 *
 * Return: 0 if a working server was found, -EIO otherwise
 */
static int dnbd3_panic_connect(struct dnbd3_device *dev)
{
	struct dnbd3_server *working = NULL;
	int i;

	for (i = 0; i < NUMBER_CONNECTIONS; i++) {
		if (dnbd3_is_sock_alive(dev->socks[i]))
			working = dev->socks[i].server;
	}
	if (working == NULL) {
		/* no socket is alive, release a leftover connection of socket 0 */
		if (dev->socks[0].sock)
			dnbd3_socket_disconnect(&dev->socks[0]);
		for (i = 0; i < NUMBER_SERVERS; i++) {
			if (dev->alt_servers[i].host.type == 0)
				continue; /* empty slot */
			if (!dnbd3_socket_connect(&dev->socks[0],
					&dev->alt_servers[i])) {
				working = &dev->alt_servers[i];
				/*
				 * stop here: another connect on the now
				 * connected socket 0 would fail and count
				 * failures against servers that were never tried
				 */
				break;
			}
		}
	}
	if (working == NULL)
		return -EIO;
	for (i = 0; i < NUMBER_CONNECTIONS; i++) {
		if (dev->socks[i].server == working)
			continue;
		/* a connect needs a socket that holds no connection */
		if (dev->socks[i].sock)
			dnbd3_socket_disconnect(&dev->socks[i]);
		dnbd3_socket_connect(&dev->socks[i], working);
	}
	return 0;
}
/**
* dnbd3_adjust_connections - create a connection plan and connect
* @dev: the dnbd3 device
*
* 1. sort the alt server after the avg rtt
* 2. create a connection plan
* 3. connect the plan
*/
static int dnbd3_adjust_connections(struct dnbd3_device *dev) {
int i, j, fallback;
struct dnbd3_server *plan[NUMBER_CONNECTIONS];
struct dnbd3_server **servers = dnbd3_sort_server(dev);
//TODO don't connect to anyting bader then rtt unknown
if (servers && servers[0]->host.type != 0) {
plan[0] = servers[0];
fallback = 0;
j = 1;
debug_dev(dev, "connection plan:");
debug_server(dev, plan[0], "server 0 with rtt %llu:",
plan[0]->avg_rtt);
for (i = 1; i < NUMBER_CONNECTIONS; i++) {
if (servers[j]->host.type != 0 &&
servers[j]->avg_rtt < RTT_UNKNOWN) {
if (RTT_FACTOR(plan[i - 1]->avg_rtt) >
servers[j]->avg_rtt) {
plan[i] = servers[j];
j++;
} else {
plan[i] = plan[fallback];
fallback++;
}
} else {
plan[i] = plan[fallback];
fallback++;
}
debug_server(dev, plan[i], "server %d with rtt %llu:",
i, plan[i]->avg_rtt);
}
kfree(servers);
for (i = 0; i < NUMBER_CONNECTIONS; i++) {
if (plan[i] != dev->socks[i].server) {
if (dnbd3_is_sock_alive(dev->socks[i])) {
dnbd3_socket_disconnect(&dev->socks[i]);
}
j = dnbd3_socket_connect(&dev->socks[i], plan[i]);
if (j) {
return j;
}
}
}
return 0;
} else { /* there is nothing to connect */
if (servers) {
kfree(servers);
}
return -ENONET;
}
}
/**
 * dnbd3_panic_worker - handle panicked sockets
 * @work: the panic_worker embedded in the dnbd3_device
 *
 * Every panicked socket has its server marked as unreachable and is
 * disconnected. If at least one socket had panicked, the connections are
 * re-planned under device_lock; if that fails dnbd3_panic_connect() is tried
 * and, if that fails too, the device is marked as not connected.
 */
static void dnbd3_panic_worker(struct work_struct *work)
{
	struct dnbd3_device *dev = container_of(work, struct dnbd3_device,
			panic_worker);
	bool panicked = false;
	int alive = 0;
	int idx;

	for (idx = 0; idx < NUMBER_CONNECTIONS; idx++) {
		if (!dev->socks[idx].panic) {
			if (dnbd3_is_sock_alive(dev->socks[idx]))
				alive++;
			continue;
		}
		panicked = true;
		dnbd3_set_rtt_unreachable(dev->socks[idx].server);
		dnbd3_socket_disconnect(&dev->socks[idx]);
	}
	if (!panicked)
		return;

	warn_dev(dev, "panicked, connections still alive %d", alive);
	mutex_lock(&dev->device_lock);
	if (dnbd3_adjust_connections(dev) && dnbd3_panic_connect(dev)) {
		error_dev(dev, "failed to connect to any server");
		dev->connected = false;
	}
	mutex_unlock(&dev->device_lock);
}
/**
* dnbd3_meassure_rtt - meassure the rtt of a server
* @dev: the device this server belongs to
* @server: the server to meassure
*/
static int dnbd3_meassure_rtt(struct dnbd3_device *dev,
struct dnbd3_server *server)
{
struct timeval start, end;
dnbd3_reply_t reply;
struct request req;
int result;
uint64_t rtt = RTT_UNREACHABLE;
struct dnbd3_sock sock = {
.sock_nr = NUMBER_CONNECTIONS,
.sock = NULL,
.device = dev,
.server = server
};
result = __dnbd3_socket_connect(&sock, server);
if (result) {
error_sock(&sock, "socket connect failed in rtt measurement");
goto error;
}
dnbd3_connect_to_req(&req);
result = dnbd3_send_request_cmd(&sock, CMD_SELECT_IMAGE);
if (result <= 0) {
error_sock(&sock, "request select image failed in rtt measurement");
goto error;
}
result = dnbd3_receive_cmd(&sock, &reply);
if (result <= 0) {
error_sock(&sock, "receive select image failed in rtt measurement");
goto error;
}
if (reply.magic != dnbd3_packet_magic || reply.cmd != CMD_SELECT_IMAGE
|| reply.size < 4) {
error_sock(&sock, "receive select image wrong header in rtt measurement");
result = -EIO;
goto error;
}
result = dnbd3_receive_cmd_select_image(&sock, &reply);
if (result <= 0) {
error_sock(&sock, "receive data select image failed in rtt measurement");
goto error;
}
do_gettimeofday(&start);
result = dnbd3_send_request_cmd(&sock, CMD_GET_BLOCK);
if (result <= 0) {
error_sock(&sock, "request test block failed in rtt measurement");
goto error;
}
result = dnbd3_receive_cmd(&sock, &reply);
if (reply.magic != dnbd3_packet_magic|| reply.cmd != CMD_GET_BLOCK
|| reply.size != RTT_BLOCK_SIZE) {
error_sock(&sock, "receive header cmd test block failed in rtt measurement");
result = -EIO;
goto error;
}
result = dnbd3_receive_cmd_get_block_test(&sock, &reply);
if (result <= 0) {
error_sock(&sock, "receive test block failed in rtt measurement");
goto error;
}
do_gettimeofday(&end); // end rtt measurement
rtt = (uint64_t)((end.tv_sec - start.tv_sec) * 1000000ull
+ (end.tv_usec - start.tv_usec));
debug_sock(&sock, "new rrt is %llu", rtt);
error:
sock.server->rtts[dev->discovery_count % 4] = rtt;
sock.server->avg_rtt = dnbd3_avg_rtt(sock.server);
if (result <= 0) {
server->failures++;
}
if (sock.sock) {
kernel_sock_shutdown(sock.sock, SHUT_RDWR);
sock.server = NULL;
sock_release(sock.sock);
sock.sock = NULL;
}
return result;
}
/**
 * dnbd3_merge_new_server - merge a received server into the alt server list
 * @dev: the device
 * @new_server: the received server entry
 *
 * If the server (same type, port and address) is already known, it is only
 * removed when the entry requests it (failures == 1). Otherwise it takes the
 * last empty slot or, if there is none, the last slot with more than 20
 * failures; the slot's failures, protocol version and rtt are reset. If no
 * slot qualifies the server is ignored.
 */
static void dnbd3_merge_new_server(struct dnbd3_device *dev,
		dnbd3_server_entry_t *new_server)
{
	struct dnbd3_server *existing_server = NULL;
	struct dnbd3_server *free_server = NULL;
	struct dnbd3_server *failed_server = NULL;
	struct dnbd3_server *slot;
	size_t addr_len = new_server->host.type == HOST_IP4 ? 4 : 16;
	int i;

	for (i = 0; i < NUMBER_SERVERS; i++) {
		slot = &dev->alt_servers[i];
		if (slot->host.type == new_server->host.type &&
				slot->host.port == new_server->host.port &&
				memcmp(new_server->host.addr, slot->host.addr,
					addr_len) == 0)
			existing_server = slot;
		else if (slot->host.type == 0)
			free_server = slot;
		else if (slot->failures > 20)
			failed_server = slot;
	}

	if (existing_server) {
		if (new_server->failures == 1) { /* remove is requested */
			info_server(dev, new_server,
					"remove server is requested");
			/* adjust connection will remove it later */
			existing_server->host.type = 0;
			dnbd3_set_rtt_unreachable(existing_server);
		}
		return;
	}

	if (!free_server) {
		if (!failed_server)
			return; /* no server found to replace */
		free_server = failed_server;
	}
	//TODO disconnect the server if it is connected
	free_server->host = new_server->host;
	info_server(dev, free_server, "got new alternative server");
	free_server->failures = 0;
	free_server->protocol_version = 0;
	dnbd3_set_rtt_unknown(free_server);
}
/**
 * dnbd3_discovery_worker - handle discovery
 * @work: the discovery_worker embedded in the dnbd3_device
 *
 * 1. merge the servers received by CMD_GET_SERVERS into the alt servers
 * 2. measure the rtt of every configured alt server
 * 3. adjust the connections, falling back to dnbd3_panic_connect(); if both
 *    fail the device is marked as not connected
 */
static void dnbd3_discovery_worker(struct work_struct *work)
{
	struct dnbd3_device *dev;
	int i;
	struct dnbd3_server *server;
	dnbd3_server_entry_t *new_server;

	dev = container_of(work, struct dnbd3_device, discovery_worker);
	debug_dev(dev, "starting discovery worker new server num is %d",
			dev->new_servers_num);
	if (dev->new_servers_num) {
		mutex_lock(&dev->device_lock);
		for (i = 0; i < dev->new_servers_num; i++) {
			new_server = &dev->new_servers[i];
			if (new_server->host.type != 0)
				dnbd3_merge_new_server(dev, new_server);
		}
		dev->new_servers_num = 0;
		mutex_unlock(&dev->device_lock);
	}
	/* measure rtt for all alt servers */
	for (i = 0; i < NUMBER_SERVERS; i++) {
		server = &dev->alt_servers[i];
		if (!server->host.type)
			continue;
		/*
		 * dnbd3_meassure_rtt() already increments server->failures on
		 * failure; counting it here again would halve the failure
		 * threshold used by dnbd3_merge_new_server()
		 */
		if (dnbd3_meassure_rtt(dev, server) <= 0)
			warn_server(dev, server, "failed to meassure rtt");
	}
	mutex_lock(&dev->device_lock);
	if (dnbd3_adjust_connections(dev)) {
		if (dnbd3_panic_connect(dev)) {
			error_dev(dev, "failed to connect to any server");
			dev->connected = false;
		}
	}
	mutex_unlock(&dev->device_lock);
	dev->discovery_count++;
}
/*
* Connect and disconnect
*/
/**
* __dnbd3_socket_connect - internal connect a socket to a server
* @sock: the socket to connect
* @server: the server
*/
static int __dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server *server)
{
int result = 0;
struct timeval timeout;
if (server->host.port == 0 || server->host.type == 0) {
error_sock(sock, "host or port not set");
return -EIO;
}
if (sock->sock) {
warn_sock(sock, "already connected");
return -EIO;
}
timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA;
timeout.tv_usec = 0;
result = dnbd3_sock_create(server->host.type, SOCK_STREAM, IPPROTO_TCP,
&sock->sock);
if (result < 0) {
error_sock(sock, "could not create socket");
goto error;
}
kernel_setsockopt(sock->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout,
sizeof(timeout));
kernel_setsockopt(sock->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout,
sizeof(timeout));
sock->sock->sk->sk_allocation = GFP_NOIO;
if (server->host.type == HOST_IP4) {
struct sockaddr_in sin;
memset(&sin, 0, sizeof(sin));
sin.sin_family = AF_INET;
memcpy(&(sin.sin_addr), server->host.addr, 4);
sin.sin_port = server->host.port;
result = kernel_connect(sock->sock, (struct sockaddr *)&sin,
sizeof(sin), 0);
if (result != 0) {
error_sock(sock, "connection to host failed");
goto error;
}
} else {
struct sockaddr_in6 sin;
memset(&sin, 0, sizeof(sin));
sin.sin6_family = AF_INET6;
memcpy(&(sin.sin6_addr), server->host.addr, 16);
sin.sin6_port = server->host.port;
result = kernel_connect(sock->sock, (struct sockaddr *)&sin,
sizeof(sin), 0);
if (result != 0){
error_sock(sock, "connection to host failed");
goto error;
}
}
return 0;
error:
if (sock->sock) {
sock_release(sock->sock);
sock->sock = NULL;
}
return result;
}
/**
 * dnbd3_socket_connect - connect a socket to a server
 * @sock: the socket, must not hold a connection
 * @server: the server to connect
 *
 * 1. connect the socket to the server
 * 2. select the image (the handshake is done synchronously here)
 * 3. start the receive worker and request the alternative servers
 *
 * On failure the server's failure counter is incremented, the socket is
 * marked as panicked and released again. A socket that still holds a
 * connection is left untouched.
 *
 * Return: 0 on success, -EIO on failure
 */
static int dnbd3_socket_connect(struct dnbd3_sock *sock,
		struct dnbd3_server *server)
{
	int result = -EIO;
	dnbd3_reply_t reply;
	struct dnbd3_device *dev = sock->device;

	/*
	 * re-initializing tx_lock or releasing the socket in the error path
	 * would break the current users of a live connection
	 */
	if (sock->sock) {
		warn_sock(sock, "socket connect called on a connected socket");
		return -EIO;
	}
	sock->server = server;
	debug_sock(sock, "socket connect");
	/* set up the workers first, so the error path may always cancel them */
	INIT_WORK(&sock->receive_worker, dnbd3_receive_worker);
	INIT_WORK(&sock->keepalive_worker, dnbd3_keepalive_worker);
	mutex_init(&sock->tx_lock);
	mutex_lock(&sock->tx_lock);
	result = __dnbd3_socket_connect(sock, server);
	mutex_unlock(&sock->tx_lock);
	if (result) {
		error_sock(sock, "connection to socket failed");
		result = -EIO;
		goto error;
	}
	sock->panic = false;
	if (!sock->sock) {
		/* the failure is counted once in the error path */
		error_sock(sock, "socket is not connected");
		result = -EIO;
		goto error;
	}
	result = dnbd3_send_request_cmd(sock, CMD_SELECT_IMAGE);
	if (result <= 0) {
		error_sock(sock, "connection to image %s failed", dev->imgname);
		result = -EIO;
		goto error;
	}
	result = dnbd3_receive_cmd(sock, &reply);
	if (result <= 0) {
		error_sock(sock, "receive cmd to image %s failed",
				dev->imgname);
		result = -EIO;
		goto error;
	}
	if (reply.magic != dnbd3_packet_magic || reply.cmd != CMD_SELECT_IMAGE
			|| reply.size < 4) {
		error_sock(sock, "receive select image wrong header %s",
				dev->imgname);
		result = -EIO;
		goto error;
	}
	result = dnbd3_receive_cmd_select_image(sock, &reply);
	if (result <= 0) {
		error_sock(sock, "receive cmd select image %s failed",
				dev->imgname);
		result = -EIO;
		goto error;
	}
	debug_sock(sock, "connected to image %s, filesize %llu", dev->imgname,
			dev->reported_size);
	/* start the receiver */
	queue_work(dnbd3_wq, &sock->receive_worker);
	/* TODO not on every connect? the receive worker handles the reply */
	if (dnbd3_send_request_cmd(sock, CMD_GET_SERVERS) <= 0)
		error_sock(sock, "failed to get servers in discovery");
	return 0;
error:
	server->failures++;
	sock->panic = true;
	if (sock->sock) {
		kernel_sock_shutdown(sock->sock, SHUT_RDWR);
		cancel_work_sync(&sock->receive_worker);
		sock_release(sock->sock);
		sock->sock = NULL;
	}
	mutex_destroy(&sock->tx_lock);
	return result;
}
/**
 * dnbd3_socket_disconnect - disconnect a socket
 * @sock: the socket to disconnect
 *
 * Cancels the keepalive worker, shuts the socket down, waits for the receive
 * worker to finish and releases the socket. Afterwards the socket has no
 * server assigned and is not marked as panicked. tx_lock is destroyed here;
 * dnbd3_socket_connect() initializes it again.
 *
 * NOTE(review): tx_lock is locked unconditionally, so this presumably must
 * only be called on a socket whose tx_lock was initialized by a connect.
 *
 * Return: always 0
 */
static int dnbd3_socket_disconnect(struct dnbd3_sock *sock)
{
cancel_work_sync(&sock->keepalive_worker);
debug_sock(sock, "socket disconnect");
mutex_lock(&sock->tx_lock);
/*
* Important sequence to shut down socket
* 1. kernel_sock_shutdown
* socket shutdown, the receiver which blocks in the socket receive
* returns 0
* 2. cancel_work_sync(receiver)
* wait for the receiver to finish, so the socket is not used
* anymore
* 3. sock_release
* release the socket and set to NULL
*/
if (sock->sock) {
kernel_sock_shutdown(sock->sock, SHUT_RDWR);
}
mutex_unlock(&sock->tx_lock);
mutex_destroy(&sock->tx_lock);
cancel_work_sync(&sock->receive_worker);
if (sock->sock) {
sock_release(sock->sock);
sock->sock = NULL;
}
sock->server = NULL;
sock->panic = false;
return 0;
}
/**
 * dnbd3_net_connect - connect device
 * @dev: the device to connect
 *
 * dnbd3_device.alt_servers[0] must be set. Connects the sockets according to
 * the connection plan and, on success, marks the device as connected and
 * starts the periodic timer that drives the panic, keepalive and discovery
 * workers.
 *
 * Return: 0 on success, -ENONET if no initial server is set, -ENOENT if the
 * initial connection failed (the device is disconnected again)
 */
int dnbd3_net_connect(struct dnbd3_device *dev)
{
	int result;

	debug_dev(dev, "connecting to server");
	if (dev->alt_servers[0].host.type == 0)
		return -ENONET;
	/*
	 * initialize the workers and the timer before the first connect: the
	 * error path calls dnbd3_net_disconnect(), which cancels all of them
	 */
	INIT_WORK(&dev->discovery_worker, dnbd3_discovery_worker);
	INIT_WORK(&dev->panic_worker, dnbd3_panic_worker);
	timer_setup(&dev->timer, dnbd3_timer, 0);

	result = dnbd3_adjust_connections(dev);
	if (result) {
		error_dev(dev, "failed to connect to initial server");
		dnbd3_net_disconnect(dev);
		return -ENOENT;
	}
	dev->connected = true;
	debug_dev(dev, "connected, starting workers");
	dev->timer.expires = jiffies + HZ;
	add_timer(&dev->timer);
	return result;
}
/**
* dnbd3_net_disconnect - disconnect device
* @dev: the device to disconnect
*/
int dnbd3_net_disconnect(struct dnbd3_device *dev)
{
int i;
int result = 0;
del_timer_sync(&dev->timer);
/* be sure it does not recover while disconnecting */
cancel_work_sync(&dev->discovery_worker);
cancel_work_sync(&dev->panic_worker);
for (i = 0; i < NUMBER_CONNECTIONS; i++) {
if (dev->socks[i].sock) {
if (dnbd3_socket_disconnect(&dev->socks[i])) {
result = -EIO;
}
}
}
dev->connected = false;
return result;
}