Split TLSv1.3 record protection from record layer.
authorjsing <jsing@openbsd.org>
Sun, 21 Mar 2021 17:25:17 +0000 (17:25 +0000)
committerjsing <jsing@openbsd.org>
Sun, 21 Mar 2021 17:25:17 +0000 (17:25 +0000)
This makes the TLSv1.2 and TLSv1.3 record layers more consistent and while
it is not currently necessary from a functionality perspective, it makes
for more readable and simpler code.

ok inoguchi@ tb@

lib/libssl/tls13_record_layer.c

index bbecc60..4be4bad 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: tls13_record_layer.c,v 1.58 2021/01/05 17:49:04 tb Exp $ */
+/* $OpenBSD: tls13_record_layer.c,v 1.59 2021/03/21 17:25:17 jsing Exp $ */
 /*
  * Copyright (c) 2018, 2019 Joel Sing <jsing@openbsd.org>
  *
@@ -25,6 +25,41 @@ static ssize_t tls13_record_layer_write_chunk(struct tls13_record_layer *rl,
 static ssize_t tls13_record_layer_write_record(struct tls13_record_layer *rl,
     uint8_t content_type, const uint8_t *content, size_t content_len);
 
+struct tls13_record_protection {
+       EVP_AEAD_CTX aead_ctx;
+       struct tls13_secret iv;
+       struct tls13_secret nonce;
+       uint8_t seq_num[TLS13_RECORD_SEQ_NUM_LEN];
+};
+
+struct tls13_record_protection *
+tls13_record_protection_new(void)
+{
+       return calloc(1, sizeof(struct tls13_record_protection));
+}
+
+void
+tls13_record_protection_clear(struct tls13_record_protection *rp)
+{
+       EVP_AEAD_CTX_cleanup(&rp->aead_ctx);
+
+       tls13_secret_cleanup(&rp->iv);
+       tls13_secret_cleanup(&rp->nonce);
+
+       memset(rp->seq_num, 0, sizeof(rp->seq_num));
+}
+
+void
+tls13_record_protection_free(struct tls13_record_protection *rp)
+{
+       if (rp == NULL)
+               return;
+
+       tls13_record_protection_clear(rp);
+
+       freezero(rp, sizeof(struct tls13_record_protection));
+}
+
 struct tls13_record_layer {
        uint16_t legacy_version;
 
@@ -75,14 +110,8 @@ struct tls13_record_layer {
        /* Record protection. */
        const EVP_MD *hash;
        const EVP_AEAD *aead;
-       EVP_AEAD_CTX read_aead_ctx;
-       EVP_AEAD_CTX write_aead_ctx;
-       struct tls13_secret read_iv;
-       struct tls13_secret write_iv;
-       struct tls13_secret read_nonce;
-       struct tls13_secret write_nonce;
-       uint8_t read_seq_num[TLS13_RECORD_SEQ_NUM_LEN];
-       uint8_t write_seq_num[TLS13_RECORD_SEQ_NUM_LEN];
+       struct tls13_record_protection *read;
+       struct tls13_record_protection *write;
 
        /* Callbacks. */
        struct tls13_record_layer_callbacks cb;
@@ -120,13 +149,23 @@ tls13_record_layer_new(const struct tls13_record_layer_callbacks *callbacks,
        struct tls13_record_layer *rl;
 
        if ((rl = calloc(1, sizeof(struct tls13_record_layer))) == NULL)
-               return NULL;
+               goto err;
+
+       if ((rl->read = tls13_record_protection_new()) == NULL)
+               goto err;
+       if ((rl->write = tls13_record_protection_new()) == NULL)
+               goto err;
 
        rl->legacy_version = TLS1_2_VERSION;
        rl->cb = *callbacks;
        rl->cb_arg = cb_arg;
 
        return rl;
+
+ err:
+       tls13_record_layer_free(rl);
+
+       return NULL;
 }
 
 void
@@ -143,13 +182,8 @@ tls13_record_layer_free(struct tls13_record_layer *rl)
 
        tls13_record_layer_rbuf_free(rl);
 
-       EVP_AEAD_CTX_cleanup(&rl->read_aead_ctx);
-       EVP_AEAD_CTX_cleanup(&rl->write_aead_ctx);
-
-       tls13_secret_cleanup(&rl->read_iv);
-       tls13_secret_cleanup(&rl->write_iv);
-       tls13_secret_cleanup(&rl->read_nonce);
-       tls13_secret_cleanup(&rl->write_nonce);
+       tls13_record_protection_free(rl->read);
+       tls13_record_protection_free(rl->write);
 
        freezero(rl, sizeof(struct tls13_record_layer));
 }
@@ -430,32 +464,28 @@ tls13_record_layer_phh(struct tls13_record_layer *rl, CBS *cbs)
 }
 
 static int
-tls13_record_layer_set_traffic_key(const EVP_AEAD *aead, EVP_AEAD_CTX *aead_ctx,
-    const EVP_MD *hash, struct tls13_secret *iv, struct tls13_secret *nonce,
-    struct tls13_secret *traffic_key)
+tls13_record_layer_set_traffic_key(const EVP_AEAD *aead, const EVP_MD *hash,
+    struct tls13_record_protection *rp, struct tls13_secret *traffic_key)
 {
        struct tls13_secret context = { .data = "", .len = 0 };
        struct tls13_secret key = { .data = NULL, .len = 0 };
        int ret = 0;
 
-       EVP_AEAD_CTX_cleanup(aead_ctx);
+       tls13_record_protection_clear(rp);
 
-       tls13_secret_cleanup(iv);
-       tls13_secret_cleanup(nonce);
-
-       if (!tls13_secret_init(iv, EVP_AEAD_nonce_length(aead)))
+       if (!tls13_secret_init(&rp->iv, EVP_AEAD_nonce_length(aead)))
                goto err;
-       if (!tls13_secret_init(nonce, EVP_AEAD_nonce_length(aead)))
+       if (!tls13_secret_init(&rp->nonce, EVP_AEAD_nonce_length(aead)))
                goto err;
        if (!tls13_secret_init(&key, EVP_AEAD_key_length(aead)))
                goto err;
 
-       if (!tls13_hkdf_expand_label(iv, hash, traffic_key, "iv", &context))
+       if (!tls13_hkdf_expand_label(&rp->iv, hash, traffic_key, "iv", &context))
                goto err;
        if (!tls13_hkdf_expand_label(&key, hash, traffic_key, "key", &context))
                goto err;
 
-       if (!EVP_AEAD_CTX_init(aead_ctx, aead, key.data, key.len,
+       if (!EVP_AEAD_CTX_init(&rp->aead_ctx, aead, key.data, key.len,
            EVP_AEAD_DEFAULT_TAG_LENGTH, NULL))
                goto err;
 
@@ -471,20 +501,16 @@ int
 tls13_record_layer_set_read_traffic_key(struct tls13_record_layer *rl,
     struct tls13_secret *read_key)
 {
-       memset(rl->read_seq_num, 0, TLS13_RECORD_SEQ_NUM_LEN);
-
-       return tls13_record_layer_set_traffic_key(rl->aead, &rl->read_aead_ctx,
-           rl->hash, &rl->read_iv, &rl->read_nonce, read_key);
+       return tls13_record_layer_set_traffic_key(rl->aead, rl->hash,
+           rl->read, read_key);
 }
 
 int
 tls13_record_layer_set_write_traffic_key(struct tls13_record_layer *rl,
     struct tls13_secret *write_key)
 {
-       memset(rl->write_seq_num, 0, TLS13_RECORD_SEQ_NUM_LEN);
-
-       return tls13_record_layer_set_traffic_key(rl->aead, &rl->write_aead_ctx,
-           rl->hash, &rl->write_iv, &rl->write_nonce, write_key);
+       return tls13_record_layer_set_traffic_key(rl->aead, rl->hash,
+           rl->write, write_key);
 }
 
 static int
@@ -541,13 +567,13 @@ tls13_record_layer_open_record_protected(struct tls13_record_layer *rl)
                goto err;
        content_len = CBS_len(&enc_record);
 
-       if (!tls13_record_layer_update_nonce(&rl->read_nonce, &rl->read_iv,
-           rl->read_seq_num))
+       if (!tls13_record_layer_update_nonce(&rl->read->nonce, &rl->read->iv,
+           rl->read->seq_num))
                goto err;
 
-       if (!EVP_AEAD_CTX_open(&rl->read_aead_ctx,
+       if (!EVP_AEAD_CTX_open(&rl->read->aead_ctx,
            content, &out_len, content_len,
-           rl->read_nonce.data, rl->read_nonce.len,
+           rl->read->nonce.data, rl->read->nonce.len,
            CBS_data(&enc_record), CBS_len(&enc_record),
            CBS_data(&header), CBS_len(&header)))
                goto err;
@@ -557,7 +583,7 @@ tls13_record_layer_open_record_protected(struct tls13_record_layer *rl)
                goto err;
        }
 
-       if (!tls13_record_layer_inc_seq_num(rl->read_seq_num))
+       if (!tls13_record_layer_inc_seq_num(rl->read->seq_num))
                goto err;
 
        /*
@@ -718,8 +744,8 @@ tls13_record_layer_seal_record_protected(struct tls13_record_layer *rl,
        if (!CBB_finish(&cbb, &data, &data_len))
                goto err;
 
-       if (!tls13_record_layer_update_nonce(&rl->write_nonce,
-           &rl->write_iv, rl->write_seq_num))
+       if (!tls13_record_layer_update_nonce(&rl->write->nonce,
+           &rl->write->iv, rl->write->seq_num))
                goto err;
 
        /*
@@ -727,16 +753,16 @@ tls13_record_layer_seal_record_protected(struct tls13_record_layer *rl,
         * this would avoid a copy since the inner would be passed as two
         * separate pieces.
         */
-       if (!EVP_AEAD_CTX_seal(&rl->write_aead_ctx,
+       if (!EVP_AEAD_CTX_seal(&rl->write->aead_ctx,
            enc_record, &out_len, enc_record_len,
-           rl->write_nonce.data, rl->write_nonce.len,
+           rl->write->nonce.data, rl->write->nonce.len,
            inner, inner_len, header, header_len))
                goto err;
 
        if (out_len != enc_record_len)
                goto err;
 
-       if (!tls13_record_layer_inc_seq_num(rl->write_seq_num))
+       if (!tls13_record_layer_inc_seq_num(rl->write->seq_num))
                goto err;
 
        if (!tls13_record_set_data(rl->wrec, data, data_len))