/*
* This file is part of the Distributed Network Block Device 3
*
* Copyright(c) 2019 Frederic Robra <frederic@robra.org>
* Parts copyright 2011-2012 Johann Latocha <johann@latocha.de>
*
* This file may be licensed under the terms of of the
* GNU General Public License Version 2 (the ``GPL'').
*
* Software distributed under the License is distributed
* on an ``AS IS'' basis, WITHOUT WARRANTY OF ANY KIND, either
* express or implied. See the GPL for the specific language
* governing rights and limitations.
*
* You should have received a copy of the GPL along with this
* program. If not, go to http://www.gnu.org/licenses/gpl.html
* or write to the Free Software Foundation, Inc.,
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*
*/
#include <net/sock.h>
#include <linux/wait.h>
#include <linux/sort.h>
#include "net.h"
#include "net-txrx.h"
#include "utils.h"
#include "clientconfig.h"
#include "mq.h"
/*
 * Wrapper around sock_create_kern() using the init_net namespace.
 * The dnbd3 host type @af picks the address family: HOST_IP4 maps to
 * AF_INET, every other value maps to AF_INET6.
 */
#define dnbd3_sock_create(af,type,proto,sock) \
	sock_create_kern(&init_net, \
			(((af) == HOST_IP4) ? AF_INET : AF_INET6), \
			type, proto, sock)
/*
 * Release the kernel socket held by the dnbd3_sock pointer @s and reset
 * the pointer to NULL so later "is it connected" checks see it as gone.
 *
 * The parameter is deliberately NOT named "sock": a parameter with the
 * same name as the struct member would also be substituted into the
 * member access, so the macro would only work for an argument that is
 * literally spelled "sock". @s is evaluated twice, so it must be free
 * of side effects.
 */
#define dnbd3_sock_release(s) \
	do { \
		sock_release((s)->sock); \
		(s)->sock = NULL; \
	} while (0)
/*
 * Non-zero when the dnbd3_sock @s (the struct itself, not a pointer)
 * has both a kernel socket and a server assigned.
 */
#define dnbd3_is_sock_alive(s) \
	((s).sock != NULL && (s).server != NULL)
/*
 * Forward declarations: the workers below call these helpers, which are
 * defined in the "Connect and disconnect" section further down.
 */
static int __dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server *server);
static int dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server *server);
static int dnbd3_socket_disconnect(struct dnbd3_sock *sock);
/*
* Timer and workers
*/
/**
* dnbd3_timer - the timer to start different workers
* @arg: the timer_list used to get the dnbd3_device
*
* workers to start:
* - panic_worker
* - keepalive_worker for each connected socket
* - discovery_worker
*/
static void dnbd3_timer(struct timer_list *arg)
{
	/* the timer is embedded in the device, so recover the device from it */
	struct dnbd3_device *dev = container_of(arg, struct dnbd3_device,
			timer);
	unsigned long busy;
	int nr;

	/* the panic worker is queued on every tick */
	queue_work(dnbd3_wq, &dev->panic_worker);

	/* bitmask returned by dnbd3_is_mq_busy(), tested per connection */
	busy = dnbd3_is_mq_busy(dev);

	if (dev->timer_count % TIMER_INTERVAL_KEEPALIVE_PACKET == 0) {
		for (nr = 0; nr < dev->number_connections; nr++) {
			/* no keepalive for connections flagged busy ... */
			if (test_bit(nr, &busy)) {
				continue;
			}
			/* ... nor for sockets without socket or server */
			if (!dnbd3_is_sock_alive(dev->socks[nr])) {
				continue;
			}
			queue_work(dnbd3_wq, &dev->socks[nr].keepalive_worker);
		}
	}

	/* start after 2 seconds */
	if (dev->timer_count % TIMER_INTERVAL_PROBE_NORMAL == 2) {
		queue_work(dnbd3_wq, &dev->discovery_worker);
	}

	dev->timer_count++;

	/* re-arm the timer so it fires again HZ jiffies from now */
	dev->timer.expires = jiffies + HZ;
	add_timer(&dev->timer);
}
/**
* dnbd3_receive_worker - receives data from a socket
* @work: the work used to get the dnbd3_sock
*
* receives data until the socket is closed (returns 0)
*/
static void dnbd3_receive_worker(struct work_struct *work)
{
struct dnbd3_sock *sock;
dnbd3_reply_t reply;
int result;
/* the work item is embedded in the dnbd3_sock, recover the socket from it */
sock = container_of(work, struct dnbd3_sock, receive_worker);
debug_sock(sock, "receive worker is starting");
while(1) { // loop until socket returns 0
/* read the next reply header from the socket */
result = dnbd3_receive_cmd(sock, &reply);
if (result == -EAGAIN) {
/* retry right away, -EAGAIN is not counted as a failure */
continue;
} else if (result <= 0) {
error_sock(sock, "connection to server lost %d", result);
goto error;
}
/* result is > 0 from here on; dispatch on the command of the reply */
switch (reply.cmd) {
case CMD_GET_BLOCK:
result = dnbd3_receive_cmd_get_block_mq(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd get block mq failed %d",
result);
goto error;
}
/* success: go straight to the next receive, skipping the error label */
continue;
case CMD_GET_SERVERS:
result = dnbd3_receive_cmd_get_servers(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd get servers failed %d",
result);
goto error;
}
break;
case CMD_LATEST_RID:
result = dnbd3_receive_cmd_latest_rid(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd latest rid failed %d",
result);
goto error;
}
break;
case CMD_KEEPALIVE:
if (reply.size != 0) {
/*
 * NOTE(review): result still holds the positive value returned by
 * dnbd3_receive_cmd() here, so the code at the error label neither
 * counts a failure nor releases the socket, and the announced
 * payload is not read in this function - confirm this is intended
 */
error_sock(sock, "got keep alive packet with payload");
goto error;
}
debug_sock(sock, "keep alive received");
break;
case CMD_SELECT_IMAGE:
result = dnbd3_receive_cmd_select_image(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd select image failed %d",
result);
goto error;
}
break;
default:
/* unknown commands are only logged; result stays positive */
warn_sock(sock, "unknown command received");
break;
}
/*
 * Reached by goto on failure and by falling out of the switch on
 * success; only result <= 0 triggers any handling below.
 */
error:
if (result == 0) {
/* 0 is treated as "socket is down": release it and stop the worker */
info_sock(sock, "result is 0, socket seems to be down");
dnbd3_sock_release(sock);
break;
} else if (result < 0) {
/* discovery takes care of too many failures */
sock->server->failures++;
warn_sock(sock, "receive error happened %d, total failures %d",
result, sock->server->failures);
/* the loop continues and tries to receive again after an error */
}
debug_sock(sock, "receive completed, waiting for next receive");
}
debug_sock(sock, "receive work queue is stopped");
}
/**
* dnbd3_keepalive_worker - sends a keepalive
* @work: the work used to get the dnbd3_sock
*/
static void dnbd3_keepalive_worker(struct work_struct *work)
{
	struct dnbd3_sock *sock = container_of(work, struct dnbd3_sock,
			keepalive_worker);

	/* nothing to do for a socket without kernel socket or server */
	if (!dnbd3_is_sock_alive(*sock)) {
		return;
	}
	debug_sock(sock, "starting keepalive worker");

	/* never block on the tx lock: if it is taken, skip this round */
	if (!mutex_trylock(&sock->tx_lock)) {
		return;
	}
	/* the result of the send is not evaluated here */
	dnbd3_send_request_cmd(sock, CMD_KEEPALIVE);
	mutex_unlock(&sock->tx_lock);
}
/**
* dnbd3_compare_servers - comparator ordering servers by their avg rtt
* @lhs: left hand side
* @rhs: right hand side
*/
static int dnbd3_compare_servers(const void *lhs, const void *rhs) {
uint64_t l, r;
struct dnbd3_server *lhs_server = *((struct dnbd3_server **) lhs);
struct dnbd3_server *rhs_server = *((struct dnbd3_server **) rhs);
l = lhs_server->host.type != 0 ? lhs_server->avg_rtt
: RTT_UNREACHABLE * 2;
r = rhs_server->host.type != 0 ? rhs_server->avg_rtt
: RTT_UNREACHABLE * 2;
return l - r;
}
/**
* dnbd3_sort_server - sort the alt servers according to their avg rtt
* @dev: the dnbd3 device
*
* the returned array has to be freed with kfree
*/
static struct dnbd3_server **dnbd3_sort_server(struct dnbd3_device *dev) {
	struct dnbd3_server **by_rtt;
	int idx;

	/* one pointer per alt server slot; the servers are not copied */
	by_rtt = kmalloc(NUMBER_SERVERS * sizeof(*by_rtt), GFP_KERNEL);
	if (by_rtt == NULL) {
		debug_dev(dev, "kmalloc failed");
		return NULL;
	}

	for (idx = 0; idx < NUMBER_SERVERS; idx++) {
		by_rtt[idx] = dev->alt_servers + idx;
	}

	/* order the pointers with dnbd3_compare_servers, no custom swap */
	sort(by_rtt, NUMBER_SERVERS, sizeof(*by_rtt), dnbd3_compare_servers,
			NULL);
	return by_rtt;
}
//static int dnbd3_panic_connect(struct dnbd3_device *dev)
//{
// struct dnbd3_server *working = NULL;
// int i;
// debug_dev(dev, "panic connect");
// for (i = 0; i < NUMBER_CONNECTIONS; i++) {
// if (dnbd3_is_sock_alive(dev->socks[i])) {
// working = dev->socks[i].server;
// debug_server(dev, working, "found server for panic");
// }
// }
// if (working == NULL) {
// for (i = 0; i < NUMBER_SERVERS; i++) {
// if (!dnbd3_socket_connect(&dev->socks[0],
// &dev->alt_servers[i])) {
// working = &dev->alt_servers[i];
// debug_server(dev, working, "found server for panic");
// }
// }
// }
// if (working == NULL) {
// return -ENOENT;
// }
// for (i = 0; i < NUMBER_CONNECTIONS; i++) {
// if (dev->socks[i].server != working) {
// dnbd3_socket_connect(&dev->socks[i], working);
// }
// }
// return 0;
//}
/**
* dnbd3_compare_plan - comparator for the connection plan
* @lhs: left hand side
* @rhs: right hand side
*/
static int dnbd3_compare_plan(const void *lhs, const void *rhs)
{
	struct dnbd3_server *first = *((struct dnbd3_server **) lhs);
	struct dnbd3_server *second = *((struct dnbd3_server **) rhs);
	uint8_t *first_addr = first->host.addr;
	uint8_t *second_addr = second->host.addr;
	/* the sort key is the port plus the sum of the 16 address bytes */
	uint64_t first_key = first->host.port;
	uint64_t second_key = second->host.port;
	int byte;

	for (byte = 0; byte < 16; byte++) {
		first_key += first_addr[byte];
		second_key += second_addr[byte];
	}

	/* the difference of the two keys decides the order */
	return first_key - second_key;
}
static void dnbd3_lock_all_socks(struct dnbd3_device *dev)
{
	int nr = 0;

	/* take the tx_lock of every socket in ascending socket order */
	while (nr < dev->number_connections) {
		mutex_lock(&dev->socks[nr].tx_lock);
		nr++;
	}
}
static void dnbd3_unlock_all_socks(struct dnbd3_device *dev)
{
	int nr = 0;

	/* drop the tx_lock of every socket in ascending socket order */
	while (nr < dev->number_connections) {
		mutex_unlock(&dev->socks[nr].tx_lock);
		nr++;
	}
}
static void dnbd3_print_conenction_plan(struct dnbd3_device *dev,
		struct dnbd3_server **plan)
{
	int slot = 0;

	/* debug output only: one line per connection slot of the plan */
	debug_dev(dev, "connection plan:");
	while (slot < dev->number_connections) {
		debug_server(dev, plan[slot], "server %d with avg rtt %llu:",
				slot, plan[slot]->avg_rtt);
		slot++;
	}
}
/**
* dnbd3_adjust_connections - create a connection plan and connect
* @dev: the dnbd3 device
*
* 1. sort the alt server after the avg rtt
* 2. create a connection plan
* 3. connect the plan
*/
static int dnbd3_adjust_connections(struct dnbd3_device *dev) {
int i, j, fallback, alive;
struct dnbd3_server **plan;
struct dnbd3_server **servers = dnbd3_sort_server(dev);
plan = kmalloc(sizeof(struct dnbd3_server *) * dev->number_connections,
GFP_KERNEL);
if (!plan) {
error_dev(dev, "kmalloc failed");
}
if (servers && servers[0]->host.type != 0) {
plan[0] = servers[0];
fallback = 0;
j = 1;
for (i = 1; i < dev->number_connections; i++) {
if (servers[j]->host.type != 0 &&
servers[j]->avg_rtt < RTT_UNKNOWN) {
if (RTT_FACTOR(plan[i - 1]->avg_rtt) >
servers[j]->avg_rtt) {
plan[i] = servers[j];
j++;
} else {
plan[i] = plan[fallback];
fallback++;
}
} else {
plan[i] = plan[fallback];
fallback++;
}
}
kfree(servers);
sort(plan, dev->number_connections, sizeof(struct dnbd3_server *),
&dnbd3_compare_plan, NULL);
dnbd3_print_conenction_plan(dev, plan);
dnbd3_lock_all_socks(dev);
alive = 0;
for (i = 0; i < dev->number_connections; i++) {
if (plan[i] != dev->socks[i].server ||
!dnbd3_is_sock_alive(dev->socks[i])) {
if (dnbd3_is_sock_alive(dev->socks[i])) {
dnbd3_socket_disconnect(&dev->socks[i]);
}
if (!dnbd3_socket_connect(&dev->socks[i],
plan[i])) {
alive++;
}
} else {
alive++;
}
}
dnbd3_unlock_all_socks(dev);
if (alive == 0) {
return -EIO;
}
return 0;
} else { /* there is nothing to connect */
debug_dev(dev, "failed to adjust connections");
if (servers) {
kfree(servers);
}
return -ENONET;
}
}
/**
* dnbd3_panic_worker - handle panicked sockets
* @work: the work used to get the dnbd3_device
*
* 1. mark the server of every dead or failing socket as unreachable
* 2. if any socket panicked, adjust the connections to reconnect to
*    good alternatives
* 3. if no server can be connected, mark the device as not connected
*/
static void dnbd3_panic_worker(struct work_struct *work)
{
struct dnbd3_device *dev;
bool panic = false;
int i;
int sock_alive = 0;
dev = container_of(work, struct dnbd3_device, panic_worker);
for (i = 0; i < dev->number_connections; i++) {
if (!dnbd3_is_sock_alive(dev->socks[i])
|| dev->socks[i].server->failures > 1000) {
panic = true;
dnbd3_set_rtt_unreachable(dev->socks[i].server);
// dnbd3_socket_disconnect(&dev->socks[i]);
} else {
sock_alive++;
}
}
if (panic) {
warn_dev(dev, "panicked, connections still alive %d",
sock_alive);
mutex_lock(&dev->device_lock);
if (dnbd3_adjust_connections(dev)) {
error_dev(dev, "failed to connect to any server");
dev->connected = false;
}
mutex_unlock(&dev->device_lock);
}
}
/**
* dnbd3_meassure_rtt - measure the rtt of a server
* @dev: the device this server belongs to
* @server: the server to measure
*/
static int dnbd3_meassure_rtt(struct dnbd3_device *dev,
		struct dnbd3_server *server)
{
	/*
	 * Returns a value > 0 on success and a value <= 0 if any step
	 * failed. In both cases the measured rtt (RTT_UNREACHABLE on
	 * failure) is stored in the server and its average is updated.
	 */
	struct timeval start, end;
	dnbd3_reply_t reply;
	int result;
	uint64_t rtt = RTT_UNREACHABLE;
	/*
	 * Temporary socket on the stack, only used for this measurement;
	 * its number is one past the regular sockets of the device.
	 */
	struct dnbd3_sock sock = {
		.sock_nr = dev->number_connections,
		.sock = NULL,
		.device = dev,
		.server = server
	};

	result = __dnbd3_socket_connect(&sock, server);
	if (result) {
		error_sock(&sock, "socket connect failed in rtt measurement");
		goto error;
	}

	/* select the image first, as on a regular connection */
	result = dnbd3_send_request_cmd(&sock, CMD_SELECT_IMAGE);
	if (result <= 0) {
		error_sock(&sock, "request select image failed in rtt measurement");
		goto error;
	}
	result = dnbd3_receive_cmd(&sock, &reply);
	if (result <= 0) {
		error_sock(&sock, "receive select image failed in rtt measurement");
		goto error;
	}
	if (reply.magic != dnbd3_packet_magic || reply.cmd != CMD_SELECT_IMAGE
			|| reply.size < 4) {
		error_sock(&sock, "receive select image wrong header in rtt measurement");
		result = -EIO;
		goto error;
	}
	result = dnbd3_receive_cmd_select_image(&sock, &reply);
	if (result <= 0) {
		error_sock(&sock, "receive data select image failed in rtt measurement");
		goto error;
	}

	/* the rtt is the time to request and receive one test block */
	do_gettimeofday(&start);
	result = dnbd3_send_request_cmd(&sock, CMD_GET_BLOCK);
	if (result <= 0) {
		error_sock(&sock, "request test block failed in rtt measurement");
		goto error;
	}
	result = dnbd3_receive_cmd(&sock, &reply);
	if (result <= 0) {
		/* without this check a stale reply header would be inspected */
		error_sock(&sock, "receive test block header failed in rtt measurement");
		goto error;
	}
	if (reply.magic != dnbd3_packet_magic || reply.cmd != CMD_GET_BLOCK
			|| reply.size != RTT_BLOCK_SIZE) {
		error_sock(&sock, "receive header cmd test block failed in rtt measurement");
		result = -EIO;
		goto error;
	}
	result = dnbd3_receive_cmd_get_block_test(&sock, &reply);
	if (result <= 0) {
		error_sock(&sock, "receive test block failed in rtt measurement");
		goto error;
	}
	do_gettimeofday(&end); // end rtt measurement

	/* elapsed time in microseconds */
	rtt = (uint64_t)((end.tv_sec - start.tv_sec) * 1000000ull
			+ (end.tv_usec - start.tv_usec));

	/* also reached on success: record the rtt and clean up the socket */
error:
	/* the last four measurements are kept, indexed by discovery round */
	sock.server->rtts[dev->discovery_count % 4] = rtt;
	sock.server->avg_rtt = dnbd3_average_rtt(sock.server);
	debug_sock(&sock, "meassured rrt: %llu; avg_rtt: %llu", rtt,
			sock.server->avg_rtt);
	if (result <= 0) {
		server->failures++;
	}
	if (sock.sock) {
		kernel_sock_shutdown(sock.sock, SHUT_RDWR);
		sock.server = NULL;
		sock_release(sock.sock);
		sock.sock = NULL;
	}
	return result;
}
/**
* dnbd3_merge_new_server - merge the new server into the alt server list
* @dev: the device
* @new_server: the new server list to merge
*/
static void dnbd3_merge_new_server(struct dnbd3_device *dev,
		dnbd3_server_entry_t *new_server)
{
	struct dnbd3_server *known = NULL;    /* slot holding the same host */
	struct dnbd3_server *unused = NULL;   /* slot with host.type == 0 */
	struct dnbd3_server *worn_out = NULL; /* slot with > 20 failures */
	struct dnbd3_server *target;
	int idx;

	/*
	 * Classify every alt server slot; the whole list is scanned, so
	 * the last slot of each kind wins.
	 */
	for (idx = 0; idx < NUMBER_SERVERS; idx++) {
		struct dnbd3_server *slot = &dev->alt_servers[idx];

		if ((new_server->host.type == slot->host.type)
				&& (new_server->host.port == slot->host.port)
				&& (0 == memcmp(new_server->host.addr,
					slot->host.addr,
					(new_server->host.type == HOST_IP4 ? 4 : 16)
					))) {
			known = slot;
		} else if (slot->host.type == 0) {
			unused = slot;
		} else if (slot->failures > 20) {
			worn_out = slot;
		}
	}

	/* the server is already listed: at most handle a removal request */
	if (known != NULL) {
		if (new_server->failures == 1) { /* remove is requested */
			info_server(dev, new_server,
					"remove server is requested");
			/* adjust connection will remove it later */
			known->host.type = 0;
			dnbd3_set_rtt_unreachable(known);
		}
		return;
	}

	/* prefer an unused slot, otherwise replace a failed server */
	if (unused != NULL) {
		/* TODO disconnect the server if it is connected */
		target = unused;
	} else if (worn_out != NULL) {
		target = worn_out;
	} else {
		/* no server found to replace */
		return;
	}

	target->host = new_server->host;
	info_server(dev, target, "got new alternative server");
	target->failures = 0;
	target->protocol_version = 0;
	dnbd3_set_rtt_unknown(target);
}
/**
* dnbd3_discovery_worker - handle discovery
* @work: the work used to get the dnbd3_device
*
* 1. check if new servers are available and set them to alternative servers
* 2. measure the rtt for all available servers
* 3. adjust the connections
*/
static void dnbd3_discovery_worker(struct work_struct *work)
{
	struct dnbd3_device *dev = container_of(work, struct dnbd3_device,
			discovery_worker);
	int idx;

	debug_dev(dev, "starting discovery worker new server num is %d",
			dev->new_servers_num);

	/* step 1: merge pending new servers under the device lock */
	if (dev->new_servers_num) {
		mutex_lock(&dev->device_lock);
		for (idx = 0; idx < dev->new_servers_num; idx++) {
			dnbd3_server_entry_t *entry = &dev->new_servers[idx];

			/* entries without a host type are skipped */
			if (entry->host.type == 0) {
				continue;
			}
			dnbd3_merge_new_server(dev, entry);
		}
		dev->new_servers_num = 0;
		mutex_unlock(&dev->device_lock);
	}

	/* step 2: measure the rtt of every configured alt server */
	for (idx = 0; idx < NUMBER_SERVERS; idx++) {
		struct dnbd3_server *candidate = &dev->alt_servers[idx];

		if (!candidate->host.type) {
			continue;
		}
		/* halve the failure count before measuring ... */
		candidate->failures /= 2;
		if (dnbd3_meassure_rtt(dev, candidate) <= 0) {
			/* ... and double it again if the measurement failed */
			candidate->failures *= 2;
			warn_server(dev, candidate,
					"failed to meassure rtt");
		}
	}

	/* step 3: rebuild the connection plan with the new rtts */
	mutex_lock(&dev->device_lock);
	if (dnbd3_adjust_connections(dev)) {
		error_dev(dev, "failed to connect to any server");
		dev->connected = false;
	}
	mutex_unlock(&dev->device_lock);

	dev->discovery_count++;
}
/*
* Connect and disconnect
*/
/**
* __dnbd3_socket_connect - internal connect a socket to a server
* @sock: the socket to connect
* @server: the server
*/
static int __dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server *server)
{
int result = 0;
struct timeval timeout;
if (server->host.port == 0 || server->host.type == 0) {
error_sock(sock, "host or port not set");
return -EIO;
}
if (sock->sock) {
warn_sock(sock, "already connected");
return -EIO;
}
timeout.tv_sec = SOCKET_TIMEOUT_CLIENT_DATA;
timeout.tv_usec = 0;
result = dnbd3_sock_create(server->host.type, SOCK_STREAM, IPPROTO_TCP,
&sock->sock);
if (result < 0) {
error_sock(sock, "could not create socket");
goto error;
}
kernel_setsockopt(sock->sock, SOL_SOCKET, SO_SNDTIMEO, (char *)&timeout,
sizeof(timeout));
kernel_setsockopt(sock->sock, SOL_SOCKET, SO_RCVTIMEO, (char *)&timeout,
sizeof(timeout));
sock->sock->sk->sk_allocation = GFP_NOIO;
if (server->host.type == HOST_IP4) {
struct sockaddr_in sin;
memset(&sin, 0, sizeof(sin));
sin.sin_family = AF_INET;
memcpy(&(sin.sin_addr), server->host.addr, 4);
sin.sin_port = server->host.port;
result = kernel_connect(sock->sock, (struct sockaddr *)&sin,
sizeof(sin), 0);
if (result != 0) {
error_sock(sock, "connection to host failed");
goto error;
}
} else {
struct sockaddr_in6 sin;
memset(&sin, 0, sizeof(sin));
sin.sin6_family = AF_INET6;
memcpy(&(sin.sin6_addr), server->host.addr, 16);
sin.sin6_port = server->host.port;
result = kernel_connect(sock->sock, (struct sockaddr *)&sin,
sizeof(sin), 0);
if (result != 0){
error_sock(sock, "connection to host failed");
goto error;
}
}
return 0;
error:
if (sock->sock) {
dnbd3_sock_release(sock);
}
return result;
}
/**
* dnbd3_socket_connect - connect a socket to a server
* @sock: the socket
* @server: the server to connect
*
* 1. connects the server to the socket
* 2. select the image
* 3. start receiver_worker and keepalive_worker
*/
static int dnbd3_socket_connect(struct dnbd3_sock *sock,
struct dnbd3_server *server)
{
/*
 * Returns 0 on success. Every failure below is reported as -EIO and
 * increments server->failures at the error label.
 */
int result = -EIO;
dnbd3_reply_t reply;
struct dnbd3_device *dev = sock->device;
/*
 * The server is assigned before connecting; the error path below does
 * not reset it, so sock->server keeps pointing to this server even if
 * the connect fails.
 */
sock->server = server;
debug_sock(sock, "socket connect");
result = __dnbd3_socket_connect(sock, server);
if (result) {
error_sock(sock, "connection to socket failed");
result = -EIO;
goto error;
}
if (!sock->sock) {
error_sock(sock, "socket is not connected");
/* this path counts the failure here and again at the error label */
server->failures++;
result = -EIO;
goto error;
}
/* handshake: request the image and validate the reply header */
result = dnbd3_send_request_cmd(sock, CMD_SELECT_IMAGE);
if (result <= 0) {
error_sock(sock, "connection to image %s failed", dev->imgname);
result = -EIO;
goto error;
}
result = dnbd3_receive_cmd(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd to image %s failed",
dev->imgname);
result = -EIO;
goto error;
}
if (reply.magic != dnbd3_packet_magic || reply.cmd != CMD_SELECT_IMAGE
|| reply.size < 4) {
error_sock(sock, "receive select image wrong header %s",
dev->imgname);
result = -EIO;
goto error;
}
/* read the payload of the select image reply */
result = dnbd3_receive_cmd_select_image(sock, &reply);
if (result <= 0) {
error_sock(sock, "receive cmd select image %s failed",
dev->imgname);
result = -EIO;
goto error;
}
debug_sock(sock, "connected to image %s, filesize %llu", dev->imgname,
dev->reported_size);
// start the receiver
INIT_WORK(&sock->receive_worker, dnbd3_receive_worker);
queue_work(dnbd3_wq, &sock->receive_worker);
/* the keepalive worker is only initialised here; the timer queues it */
INIT_WORK(&sock->keepalive_worker, dnbd3_keepalive_worker);
/* request alternative servers receiver will handle this */
if (dnbd3_send_request_cmd(sock, CMD_GET_SERVERS) <= 0) {
/* only logged: the connect still counts as successful */
error_sock(sock, "failed to get servers in discovery");
}
return 0;
error:
server->failures++;
if (sock->sock) {
/* same order as in dnbd3_socket_disconnect: shutdown, cancel, release */
kernel_sock_shutdown(sock->sock, SHUT_RDWR);
/*
 * NOTE(review): every goto above happens before INIT_WORK in this
 * call, so this relies on receive_worker having been initialised by
 * an earlier connect - confirm for the very first connect
 */
cancel_work_sync(&sock->receive_worker);
dnbd3_sock_release(sock);
}
return result;
}
/**
* dnbd3_socket_disconnect - disconnect a socket
* @sock: the socket to disconnect
*/
static int dnbd3_socket_disconnect(struct dnbd3_sock *sock)
{
/* wait for a running keepalive worker to finish before tearing down */
cancel_work_sync(&sock->keepalive_worker);
debug_sock(sock, "socket disconnect");
/*
* Important sequence to shut down socket
* 1. kernel_sock_shutdown
* socket shutdown, the receiver which blocks in socket receive
* returns 0
* 2. cancel_work_sync(receiver)
* wait for the receiver to finish, so the socket is not used
* anymore
* 3. sock_release
* release the socket and set to NULL
*/
if (sock->sock) {
kernel_sock_shutdown(sock->sock, SHUT_RDWR);
}
cancel_work_sync(&sock->receive_worker);
/*
 * dndb3_reque_busy_requests() is defined outside this file; presumably
 * it requeues the requests still pending on this socket - TODO confirm
 */
dndb3_reque_busy_requests(sock);
/* the receive worker may already have released the socket itself */
if (sock->sock) {
dnbd3_sock_release(sock);
}
/* without a server the socket no longer counts as alive */
sock->server = NULL;
/* always reports success */
return 0;
}
/**
* dnbd3_net_connect - connect device
* @dev: the device to connect
*
* dnbd3_device.alt_servers[0] must be set
*/
int dnbd3_net_connect(struct dnbd3_device *dev)
{
int i, result;
debug_dev(dev, "connecting to server");
dev->socks = kmalloc(sizeof(struct dnbd3_sock) *
dev->number_connections, GFP_KERNEL);
if (!dev->socks) {
error_dev(dev, "kmalloc failed");
}
memset(dev->socks, 0, sizeof(struct dnbd3_sock) *
dev->number_connections);
for (i = 0; i < dev->number_connections; i++) {
dev->socks[i].device = dev;
dev->socks[i].sock_nr = i;
mutex_init(&dev->socks[i].tx_lock);
}
debug_dev(dev, "set nr hw queues to %d", dev->number_connections);
blk_mq_update_nr_hw_queues(&dev->tag_set, dev->number_connections);
if (dev->alt_servers[0].host.type == 0) {
return -ENONET;
}
result = dnbd3_adjust_connections(dev);
if (result) {
error_dev(dev, "failed to connect to initial server");
dnbd3_net_disconnect(dev);
return -ENOENT;
}
dev->connected = true;
debug_dev(dev, "connected, starting workers");
INIT_WORK(&dev->discovery_worker, dnbd3_discovery_worker);
INIT_WORK(&dev->panic_worker, dnbd3_panic_worker);
timer_setup(&dev->timer, dnbd3_timer, 0);
dev->timer.expires = jiffies + HZ;
add_timer(&dev->timer);
// alt_server[0] is the initial server
// result = dnbd3_server_connect(dev, &dev->alt_servers[0]);
// if (result) {
// error_dev(dev, "failed to connect to initial server");
// result = -ENOENT;
// dev->imgname = NULL;
// dev->socks[0].server = NULL;
// }
return result;
}
/**
* dnbd3_net_disconnect - disconnect device
* @dev: the device to disconnect
*/
int dnbd3_net_disconnect(struct dnbd3_device *dev)
{
	/*
	 * Returns 0 on success (also when the device was not connected)
	 * and -EIO if disconnecting one of the sockets failed.
	 */
	int i;
	int result = 0;

	if (!dev->socks) {
		/* nothing to tear down; do not walk a NULL socket array */
		info_dev(dev, "not connected");
		dev->connected = false;
		return 0;
	}

	del_timer_sync(&dev->timer);
	/* be sure it does not recover while disconnecting */
	cancel_work_sync(&dev->discovery_worker);
	cancel_work_sync(&dev->panic_worker);

	for (i = 0; i < dev->number_connections; i++) {
		if (dev->socks[i].sock) {
			/* hold the tx lock while the socket is taken down */
			mutex_lock(&dev->socks[i].tx_lock);
			if (dnbd3_socket_disconnect(&dev->socks[i])) {
				result = -EIO;
			}
			mutex_unlock(&dev->socks[i].tx_lock);
			mutex_destroy(&dev->socks[i].tx_lock);
		}
	}

	kfree(dev->socks);
	dev->socks = NULL;
	dev->connected = false;
	return result;
}