summaryrefslogtreecommitdiffstats
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c142
1 files changed, 110 insertions, 32 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 4674e57e66b0..f208f8455ef2 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -261,24 +261,36 @@ void tls_ctx_free(struct tls_context *ctx)
kfree(ctx);
}
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static void tls_ctx_free_deferred(struct work_struct *gc)
{
- struct tls_context *ctx = tls_get_ctx(sk);
- long timeo = sock_sndtimeo(sk, 0);
- void (*sk_proto_close)(struct sock *sk, long timeout);
- bool free_ctx = false;
-
- lock_sock(sk);
- sk_proto_close = ctx->sk_proto_close;
+ struct tls_context *ctx = container_of(gc, struct tls_context, gc);
- if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
- goto skip_tx_cleanup;
+ /* Ensure any remaining work items are completed. The sk will
+ * already have lost its tls_ctx reference by the time we get
+ * here so no xmit operation will actually be performed.
+ */
+ if (ctx->tx_conf == TLS_SW) {
+ tls_sw_cancel_work_tx(ctx);
+ tls_sw_free_ctx_tx(ctx);
+ }
- if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
- free_ctx = true;
- goto skip_tx_cleanup;
+ if (ctx->rx_conf == TLS_SW) {
+ tls_sw_strparser_done(ctx);
+ tls_sw_free_ctx_rx(ctx);
}
+ tls_ctx_free(ctx);
+}
+
+static void tls_ctx_free_wq(struct tls_context *ctx)
+{
+ INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
+ schedule_work(&ctx->gc);
+}
+
+static void tls_sk_proto_cleanup(struct sock *sk,
+ struct tls_context *ctx, long timeo)
+{
if (unlikely(sk->sk_write_pending) &&
!wait_on_pending_writer(sk, &timeo))
tls_handle_open_record(sk, 0);
@@ -287,7 +299,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
if (ctx->tx_conf == TLS_SW) {
kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
- tls_sw_free_resources_tx(sk);
+ tls_sw_release_resources_tx(sk);
#ifdef CONFIG_TLS_DEVICE
} else if (ctx->tx_conf == TLS_HW) {
tls_device_free_resources_tx(sk);
@@ -295,26 +307,67 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
}
if (ctx->rx_conf == TLS_SW)
- tls_sw_free_resources_rx(sk);
+ tls_sw_release_resources_rx(sk);
#ifdef CONFIG_TLS_DEVICE
if (ctx->rx_conf == TLS_HW)
tls_device_offload_cleanup_rx(sk);
-
- if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
- {
#endif
- tls_ctx_free(ctx);
- ctx = NULL;
+}
+
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+ struct inet_connection_sock *icsk = inet_csk(sk);
+ long timeo = sock_sndtimeo(sk, 0);
+ struct tls_context *ctx;
+
+ if (unlikely(!icsk->icsk_ulp_data)) {
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
}
-skip_tx_cleanup:
+ ctx = tls_get_ctx(sk);
+ tls_sk_proto_cleanup(sk, ctx, timeo);
+ write_lock_bh(&sk->sk_callback_lock);
+ icsk->icsk_ulp_data = NULL;
+ sk->sk_prot = ctx->sk_proto;
+ write_unlock_bh(&sk->sk_callback_lock);
+
+ if (ctx->sk_proto->unhash)
+ ctx->sk_proto->unhash(sk);
+ tls_ctx_free_wq(ctx);
+}
+
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+ struct inet_connection_sock *icsk = inet_csk(sk);
+ struct tls_context *ctx = tls_get_ctx(sk);
+ long timeo = sock_sndtimeo(sk, 0);
+ bool free_ctx;
+
+ if (ctx->tx_conf == TLS_SW)
+ tls_sw_cancel_work_tx(ctx);
+
+ lock_sock(sk);
+ free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
+
+ if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
+ tls_sk_proto_cleanup(sk, ctx, timeo);
+
+ write_lock_bh(&sk->sk_callback_lock);
+ if (free_ctx)
+ icsk->icsk_ulp_data = NULL;
+ sk->sk_prot = ctx->sk_proto;
+ write_unlock_bh(&sk->sk_callback_lock);
release_sock(sk);
- sk_proto_close(sk, timeout);
- /* free ctx for TLS_HW_RECORD, used by tcp_set_state
- * for sk->sk_prot->unhash [tls_hw_unhash]
- */
+ if (ctx->tx_conf == TLS_SW)
+ tls_sw_free_ctx_tx(ctx);
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+ tls_sw_strparser_done(ctx);
+ if (ctx->rx_conf == TLS_SW)
+ tls_sw_free_ctx_rx(ctx);
+ ctx->sk_proto_close(sk, timeout);
+
if (free_ctx)
tls_ctx_free(ctx);
}
@@ -526,6 +579,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
{
#endif
rc = tls_set_sw_offload(sk, ctx, 1);
+ if (rc)
+ goto err_crypto_info;
conf = TLS_SW;
}
} else {
@@ -537,13 +592,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
{
#endif
rc = tls_set_sw_offload(sk, ctx, 0);
+ if (rc)
+ goto err_crypto_info;
conf = TLS_SW;
}
+ tls_sw_strparser_arm(sk, ctx);
}
- if (rc)
- goto err_crypto_info;
-
if (tx)
ctx->tx_conf = conf;
else
@@ -607,6 +662,7 @@ static struct tls_context *create_ctx(struct sock *sk)
ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
+ ctx->unhash = sk->sk_prot->unhash;
return ctx;
}
@@ -730,6 +786,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
+ prot[TLS_BASE][TLS_BASE].unhash = tls_sk_proto_unhash;
prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
@@ -747,16 +804,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
#ifdef CONFIG_TLS_DEVICE
prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+ prot[TLS_HW][TLS_BASE].unhash = base->unhash;
prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_HW][TLS_SW].unhash = base->unhash;
prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+ prot[TLS_BASE][TLS_HW].unhash = base->unhash;
prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+ prot[TLS_SW][TLS_HW].unhash = base->unhash;
prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
@@ -764,7 +825,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
- prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
}
static int tls_init(struct sock *sk)
@@ -773,7 +833,7 @@ static int tls_init(struct sock *sk)
int rc = 0;
if (tls_hw_prot(sk))
- goto out;
+ return 0;
/* The TLS ulp is currently supported only for TCP sockets
* in ESTABLISHED state.
@@ -784,21 +844,38 @@ static int tls_init(struct sock *sk)
if (sk->sk_state != TCP_ESTABLISHED)
return -ENOTSUPP;
+ tls_build_proto(sk);
+
/* allocate tls context */
+ write_lock_bh(&sk->sk_callback_lock);
ctx = create_ctx(sk);
if (!ctx) {
rc = -ENOMEM;
goto out;
}
- tls_build_proto(sk);
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
+ ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx);
out:
+ write_unlock_bh(&sk->sk_callback_lock);
return rc;
}
+static void tls_update(struct sock *sk, struct proto *p)
+{
+ struct tls_context *ctx;
+
+ ctx = tls_get_ctx(sk);
+ if (likely(ctx)) {
+ ctx->sk_proto_close = p->close;
+ ctx->sk_proto = p;
+ } else {
+ sk->sk_prot = p;
+ }
+}
+
void tls_register_device(struct tls_device *device)
{
spin_lock_bh(&device_spinlock);
@@ -819,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.name = "tls",
.owner = THIS_MODULE,
.init = tls_init,
+ .update = tls_update,
};
static int __init tls_register(void)