Convert legacy TLS client to tls_key_share.
authorjsing <jsing@openbsd.org>
Thu, 6 Jan 2022 18:23:56 +0000 (18:23 +0000)
committerjsing <jsing@openbsd.org>
Thu, 6 Jan 2022 18:23:56 +0000 (18:23 +0000)
This requires adding DHE support to tls_key_share. In doing so,
tls_key_share_peer_public() has to lose the group argument and gains
an invalid_key argument. The one place that actually needs the group
check is tlsext_keyshare_client_parse(), so add code to do this.

ok inoguchi@ tb@

lib/libssl/s3_lib.c
lib/libssl/ssl_cert.c
lib/libssl/ssl_clnt.c
lib/libssl/ssl_locl.h
lib/libssl/ssl_tlsext.c
lib/libssl/tls_internal.h
lib/libssl/tls_key_share.c

index b83a380..54261c5 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: s3_lib.c,v 1.220 2022/01/05 17:10:02 jsing Exp $ */
+/* $OpenBSD: s3_lib.c,v 1.221 2022/01/06 18:23:56 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1665,35 +1665,17 @@ long
 _SSL_get_peer_tmp_key(SSL *s, EVP_PKEY **key)
 {
        EVP_PKEY *pkey = NULL;
-       SESS_CERT *sc;
        int ret = 0;
 
        *key = NULL;
 
-       if (s->session == NULL || s->session->sess_cert == NULL)
-               return 0;
-
-       sc = s->session->sess_cert;
+       if (S3I(s)->hs.key_share == NULL)
+               goto err;
 
        if ((pkey = EVP_PKEY_new()) == NULL)
-               return 0;
-
-       if (sc->peer_dh_tmp != NULL) {
-               if (!EVP_PKEY_set1_DH(pkey, sc->peer_dh_tmp))
-                       goto err;
-       } else if (sc->peer_ecdh_tmp) {
-               if (!EVP_PKEY_set1_EC_KEY(pkey, sc->peer_ecdh_tmp))
-                       goto err;
-       } else if (sc->peer_x25519_tmp != NULL) {
-               if (!ssl_kex_dummy_ecdhe_x25519(pkey))
-                       goto err;
-       } else if (S3I(s)->hs.key_share != NULL) {
-               if (!tls_key_share_peer_pkey(S3I(s)->hs.key_share,
-                   pkey))
-                       goto err;
-       } else {
                goto err;
-       }
+       if (!tls_key_share_peer_pkey(S3I(s)->hs.key_share, pkey))
+               goto err;
 
        *key = pkey;
        pkey = NULL;
index 3b38820..6eece6d 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_cert.c,v 1.88 2021/11/29 18:36:27 tb Exp $ */
+/* $OpenBSD: ssl_cert.c,v 1.89 2022/01/06 18:23:56 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -395,10 +395,6 @@ ssl_sess_cert_free(SESS_CERT *sc)
        for (i = 0; i < SSL_PKEY_NUM; i++)
                X509_free(sc->peer_pkeys[i].x509);
 
-       DH_free(sc->peer_dh_tmp);
-       EC_KEY_free(sc->peer_ecdh_tmp);
-       free(sc->peer_x25519_tmp);
-
        free(sc);
 }
 
index 80a16f1..c3912c3 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_clnt.c,v 1.126 2022/01/04 12:53:31 jsing Exp $ */
+/* $OpenBSD: ssl_clnt.c,v 1.127 2022/01/06 18:23:56 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1223,20 +1223,23 @@ ssl3_get_server_certificate(SSL *s)
 static int
 ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
 {
+       int nid = NID_dhKeyAgreement;
        int invalid_params, invalid_key;
-       SESS_CERT *sc = NULL;
-       DH *dh = NULL;
+       SESS_CERT *sc;
        long alg_a;
 
        alg_a = S3I(s)->hs.cipher->algorithm_auth;
        sc = s->session->sess_cert;
 
-       if ((dh = DH_new()) == NULL)
+       tls_key_share_free(S3I(s)->hs.key_share);
+       if ((S3I(s)->hs.key_share = tls_key_share_new_nid(nid)) == NULL)
                goto err;
 
-       if (!ssl_kex_peer_params_dhe(dh, cbs, &invalid_params))
+       if (!tls_key_share_peer_params(S3I(s)->hs.key_share, cbs,
+           &invalid_params))
                goto decode_err;
-       if (!ssl_kex_peer_public_dhe(dh, cbs, &invalid_key))
+       if (!tls_key_share_peer_public(S3I(s)->hs.key_share, cbs,
+           &invalid_key))
                goto decode_err;
 
        if (invalid_params) {
@@ -1256,72 +1259,12 @@ ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
                /* XXX - Anonymous DH, so no certificate or pkey. */
                *pkey = NULL;
 
-       sc->peer_dh_tmp = dh;
-
        return 1;
 
  decode_err:
        SSLerror(s, SSL_R_BAD_PACKET_LENGTH);
        ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
 
- err:
-       DH_free(dh);
-
-       return 0;
-}
-
-static int
-ssl3_get_server_kex_ecdhe_ecp(SSL *s, SESS_CERT *sc, int nid, CBS *public)
-{
-       EC_KEY *ecdh = NULL;
-       int ret = 0;
-
-       /* Extract the server's ephemeral ECDH public key. */
-       if ((ecdh = EC_KEY_new()) == NULL) {
-               SSLerror(s, ERR_R_MALLOC_FAILURE);
-               goto err;
-       }
-       if (!ssl_kex_peer_public_ecdhe_ecp(ecdh, nid, public)) {
-               SSLerror(s, SSL_R_BAD_ECPOINT);
-               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-               goto err;
-       }
-
-       sc->peer_nid = nid;
-       sc->peer_ecdh_tmp = ecdh;
-       ecdh = NULL;
-
-       ret = 1;
-
- err:
-       EC_KEY_free(ecdh);
-
-       return (ret);
-}
-
-static int
-ssl3_get_server_kex_ecdhe_ecx(SSL *s, SESS_CERT *sc, int nid, CBS *public)
-{
-       size_t outlen;
-
-       if (nid != NID_X25519) {
-               SSLerror(s, ERR_R_INTERNAL_ERROR);
-               goto err;
-       }
-
-       if (CBS_len(public) != X25519_KEY_LENGTH) {
-               SSLerror(s, SSL_R_BAD_ECPOINT);
-               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
-               goto err;
-       }
-
-       if (!CBS_stow(public, &sc->peer_x25519_tmp, &outlen)) {
-               SSLerror(s, ERR_R_MALLOC_FAILURE);
-               goto err;
-       }
-
-       return 1;
-
  err:
        return 0;
 }
@@ -1334,7 +1277,6 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
        uint16_t curve_id;
        SESS_CERT *sc;
        long alg_a;
-       int nid;
 
        alg_a = S3I(s)->hs.cipher->algorithm_auth;
        sc = s->session->sess_cert;
@@ -1346,8 +1288,8 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
 
        /* Only named curves are supported. */
        if (curve_type != NAMED_CURVE_TYPE) {
-               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
                SSLerror(s, SSL_R_UNSUPPORTED_ELLIPTIC_CURVE);
+               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
                goto err;
        }
 
@@ -1364,19 +1306,12 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
                goto err;
        }
 
-       if ((nid = tls1_ec_curve_id2nid(curve_id)) == 0) {
-               SSLerror(s, SSL_R_UNABLE_TO_FIND_ECDH_PARAMETERS);
-               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+       tls_key_share_free(S3I(s)->hs.key_share);
+       if ((S3I(s)->hs.key_share = tls_key_share_new(curve_id)) == NULL)
                goto err;
-       }
 
-       if (nid == NID_X25519) {
-               if (!ssl3_get_server_kex_ecdhe_ecx(s, sc, nid, &public))
-                       goto err;
-       } else {
-               if (!ssl3_get_server_kex_ecdhe_ecp(s, sc, nid, &public))
-                       goto err;
-       }
+       if (!tls_key_share_peer_public(S3I(s)->hs.key_share, &public, NULL))
+               goto err;
 
        /*
         * The ECC/TLS specification does not mention the use of DSA to sign
@@ -1446,16 +1381,7 @@ ssl3_get_server_key_exchange(SSL *s)
                return (1);
        }
 
-       if (s->session->sess_cert != NULL) {
-               DH_free(s->session->sess_cert->peer_dh_tmp);
-               s->session->sess_cert->peer_dh_tmp = NULL;
-
-               EC_KEY_free(s->session->sess_cert->peer_ecdh_tmp);
-               s->session->sess_cert->peer_ecdh_tmp = NULL;
-
-               free(s->session->sess_cert->peer_x25519_tmp);
-               s->session->sess_cert->peer_x25519_tmp = NULL;
-       } else {
+       if (s->session->sess_cert == NULL) {
                s->session->sess_cert = ssl_sess_cert_new();
                if (s->session->sess_cert == NULL)
                        goto err;
@@ -1966,28 +1892,22 @@ ssl3_send_client_kex_rsa(SSL *s, SESS_CERT *sess_cert, CBB *cbb)
 static int
 ssl3_send_client_kex_dhe(SSL *s, SESS_CERT *sess_cert, CBB *cbb)
 {
-       DH *dh_clnt = NULL;
-       DH *dh_srvr;
        uint8_t *key = NULL;
        size_t key_len = 0;
        int ret = 0;
 
        /* Ensure that we have an ephemeral key from the server for DHE. */
-       if ((dh_srvr = sess_cert->peer_dh_tmp) == NULL) {
+       if (S3I(s)->hs.key_share == NULL) {
                ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
                SSLerror(s, SSL_R_UNABLE_TO_FIND_DH_PARAMETERS);
                goto err;
        }
 
-       if ((dh_clnt = DH_new()) == NULL)
+       if (!tls_key_share_generate(S3I(s)->hs.key_share))
                goto err;
-
-       if (!ssl_kex_generate_dhe(dh_clnt, dh_srvr))
+       if (!tls_key_share_public(S3I(s)->hs.key_share, cbb))
                goto err;
-       if (!ssl_kex_public_dhe(dh_clnt, cbb))
-               goto err;
-
-       if (!ssl_kex_derive_dhe(dh_clnt, dh_srvr, &key, &key_len))
+       if (!tls_key_share_derive(S3I(s)->hs.key_share, &key, &key_len))
                goto err;
 
        if (!tls12_derive_master_secret(s, key, key_len))
@@ -1996,38 +1916,37 @@ ssl3_send_client_kex_dhe(SSL *s, SESS_CERT *sess_cert, CBB *cbb)
        ret = 1;
 
  err:
-       DH_free(dh_clnt);
        freezero(key, key_len);
 
        return ret;
 }
 
 static int
-ssl3_send_client_kex_ecdhe_ecp(SSL *s, SESS_CERT *sc, CBB *cbb)
+ssl3_send_client_kex_ecdhe(SSL *s, SESS_CERT *sc, CBB *cbb)
 {
-       EC_KEY *ecdh = NULL;
        uint8_t *key = NULL;
        size_t key_len = 0;
+       CBB public;
        int ret = 0;
-       CBB ecpoint;
 
-       if ((ecdh = EC_KEY_new()) == NULL) {
-               SSLerror(s, ERR_R_MALLOC_FAILURE);
+       /* Ensure that we have an ephemeral key for ECDHE. */
+       if (S3I(s)->hs.key_share == NULL) {
+               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
+               SSLerror(s, ERR_R_INTERNAL_ERROR);
                goto err;
        }
 
-       if (!ssl_kex_generate_ecdhe_ecp(ecdh, sc->peer_nid))
+       if (!tls_key_share_generate(S3I(s)->hs.key_share))
                goto err;
 
-       /* Encode our public key. */
-       if (!CBB_add_u8_length_prefixed(cbb, &ecpoint))
-               goto err;
-       if (!ssl_kex_public_ecdhe_ecp(ecdh, &ecpoint))
+       if (!CBB_add_u8_length_prefixed(cbb, &public))
+               return 0;
+       if (!tls_key_share_public(S3I(s)->hs.key_share, &public))
                goto err;
        if (!CBB_flush(cbb))
                goto err;
 
-       if (!ssl_kex_derive_ecdhe_ecp(ecdh, sc->peer_ecdh_tmp, &key, &key_len))
+       if (!tls_key_share_derive(S3I(s)->hs.key_share, &key, &key_len))
                goto err;
 
        if (!tls12_derive_master_secret(s, key, key_len))
@@ -2037,71 +1956,10 @@ ssl3_send_client_kex_ecdhe_ecp(SSL *s, SESS_CERT *sc, CBB *cbb)
 
  err:
        freezero(key, key_len);
-       EC_KEY_free(ecdh);
-
-       return ret;
-}
-
-static int
-ssl3_send_client_kex_ecdhe_ecx(SSL *s, SESS_CERT *sc, CBB *cbb)
-{
-       uint8_t *public_key = NULL, *private_key = NULL, *shared_key = NULL;
-       int ret = 0;
-       CBB ecpoint;
-
-       /* Generate X25519 key pair and derive shared key. */
-       if ((public_key = malloc(X25519_KEY_LENGTH)) == NULL)
-               goto err;
-       if ((private_key = malloc(X25519_KEY_LENGTH)) == NULL)
-               goto err;
-       if ((shared_key = malloc(X25519_KEY_LENGTH)) == NULL)
-               goto err;
-       X25519_keypair(public_key, private_key);
-       if (!X25519(shared_key, private_key, sc->peer_x25519_tmp))
-               goto err;
-
-       /* Serialize the public key. */
-       if (!CBB_add_u8_length_prefixed(cbb, &ecpoint))
-               goto err;
-       if (!CBB_add_bytes(&ecpoint, public_key, X25519_KEY_LENGTH))
-               goto err;
-       if (!CBB_flush(cbb))
-               goto err;
-
-       if (!tls12_derive_master_secret(s, shared_key, X25519_KEY_LENGTH))
-               goto err;
-
-       ret = 1;
-
- err:
-       free(public_key);
-       freezero(private_key, X25519_KEY_LENGTH);
-       freezero(shared_key, X25519_KEY_LENGTH);
 
        return ret;
 }
 
-static int
-ssl3_send_client_kex_ecdhe(SSL *s, SESS_CERT *sc, CBB *cbb)
-{
-       if (sc->peer_x25519_tmp != NULL) {
-               if (ssl3_send_client_kex_ecdhe_ecx(s, sc, cbb) != 1)
-                       goto err;
-       } else if (sc->peer_ecdh_tmp != NULL) {
-               if (ssl3_send_client_kex_ecdhe_ecp(s, sc, cbb) != 1)
-                       goto err;
-       } else {
-               ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
-               SSLerror(s, ERR_R_INTERNAL_ERROR);
-               goto err;
-       }
-
-       return 1;
-
- err:
-       return 0;
-}
-
 static int
 ssl3_send_client_kex_gost(SSL *s, SESS_CERT *sess_cert, CBB *cbb)
 {
@@ -2627,7 +2485,7 @@ ssl3_check_cert_and_algorithm(SSL *s)
        long             alg_k, alg_a;
        EVP_PKEY        *pkey = NULL;
        SESS_CERT       *sc;
-       DH              *dh;
+       int nid = NID_undef;
 
        alg_k = S3I(s)->hs.cipher->algorithm_mkey;
        alg_a = S3I(s)->hs.cipher->algorithm_auth;
@@ -2641,7 +2499,9 @@ ssl3_check_cert_and_algorithm(SSL *s)
                SSLerror(s, ERR_R_INTERNAL_ERROR);
                goto err;
        }
-       dh = s->session->sess_cert->peer_dh_tmp;
+
+       if (S3I(s)->hs.key_share != NULL)
+               nid = tls_key_share_nid(S3I(s)->hs.key_share);
 
        /* This is the passed certificate. */
 
@@ -2670,7 +2530,7 @@ ssl3_check_cert_and_algorithm(SSL *s)
                goto fatal_err;
        }
        if ((alg_k & SSL_kDHE) &&
-           !(has_bits(i, EVP_PK_DH|EVP_PKT_EXCH) || (dh != NULL))) {
+           !(has_bits(i, EVP_PK_DH|EVP_PKT_EXCH) || (nid == NID_dhKeyAgreement))) {
                SSLerror(s, SSL_R_MISSING_DH_KEY);
                goto fatal_err;
        }
index d6d20c2..83b40d2 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_locl.h,v 1.373 2022/01/05 17:10:02 jsing Exp $ */
+/* $OpenBSD: ssl_locl.h,v 1.374 2022/01/06 18:23:56 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1242,11 +1242,6 @@ typedef struct sess_cert_st {
        /* Obviously we don't have the private keys of these,
         * so maybe we shouldn't even use the CERT_PKEY type here. */
 
-       int peer_nid;
-       DH *peer_dh_tmp;
-       EC_KEY *peer_ecdh_tmp;
-       uint8_t *peer_x25519_tmp;
-
        int references; /* actually always 1 at the moment */
 } SESS_CERT;
 
index 4cc4065..71955d9 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_tlsext.c,v 1.104 2022/01/05 17:10:02 jsing Exp $ */
+/* $OpenBSD: ssl_tlsext.c,v 1.105 2022/01/06 18:23:56 jsing Exp $ */
 /*
  * Copyright (c) 2016, 2017, 2019 Joel Sing <jsing@openbsd.org>
  * Copyright (c) 2017 Doug Hogan <doug@openbsd.org>
@@ -1510,11 +1510,10 @@ tlsext_keyshare_server_parse(SSL *s, uint16_t msg_type, CBS *cbs, int *alert)
                        continue;
 
                /* Decode and store the selected key share. */
-               S3I(s)->hs.key_share = tls_key_share_new(group);
-               if (S3I(s)->hs.key_share == NULL)
+               if ((S3I(s)->hs.key_share = tls_key_share_new(group)) == NULL)
                        goto err;
                if (!tls_key_share_peer_public(S3I(s)->hs.key_share,
-                   group, &key_exchange))
+                   &key_exchange, NULL))
                        goto err;
        }
 
@@ -1568,7 +1567,7 @@ tlsext_keyshare_client_parse(SSL *s, uint16_t msg_type, CBS *cbs, int *alert)
 
        /* Unpack server share. */
        if (!CBS_get_u16(cbs, &group))
-               goto err;
+               return 0;
 
        if (CBS_len(cbs) == 0) {
                /* HRR does not include an actual key share, only the group. */
@@ -1584,16 +1583,13 @@ tlsext_keyshare_client_parse(SSL *s, uint16_t msg_type, CBS *cbs, int *alert)
 
        if (S3I(s)->hs.key_share == NULL)
                return 0;
-
+       if (tls_key_share_group(S3I(s)->hs.key_share) != group)
+               return 0;
        if (!tls_key_share_peer_public(S3I(s)->hs.key_share,
-           group, &key_exchange))
-               goto err;
+           &key_exchange, NULL))
+               return 0;
 
        return 1;
-
- err:
-       *alert = SSL_AD_DECODE_ERROR;
-       return 0;
 }
 
 /*
index 87c7f3b..7e2bead 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: tls_internal.h,v 1.2 2022/01/05 17:10:03 jsing Exp $ */
+/* $OpenBSD: tls_internal.h,v 1.3 2022/01/06 18:23:56 jsing Exp $ */
 /*
  * Copyright (c) 2018, 2019, 2021 Joel Sing <jsing@openbsd.org>
  *
@@ -63,11 +63,14 @@ struct tls_key_share *tls_key_share_new_nid(int nid);
 void tls_key_share_free(struct tls_key_share *ks);
 
 uint16_t tls_key_share_group(struct tls_key_share *ks);
+int tls_key_share_nid(struct tls_key_share *ks);
 int tls_key_share_peer_pkey(struct tls_key_share *ks, EVP_PKEY *pkey);
 int tls_key_share_generate(struct tls_key_share *ks);
 int tls_key_share_public(struct tls_key_share *ks, CBB *cbb);
-int tls_key_share_peer_public(struct tls_key_share *ks, uint16_t group,
-    CBS *cbs);
+int tls_key_share_peer_params(struct tls_key_share *ks, CBS *cbs,
+    int *invalid_params);
+int tls_key_share_peer_public(struct tls_key_share *ks, CBS *cbs,
+    int *invalid_key);
 int tls_key_share_derive(struct tls_key_share *ks, uint8_t **shared_key,
     size_t *shared_key_len);
 
index 1bce651..6e390f4 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: tls_key_share.c,v 1.1 2022/01/05 17:10:03 jsing Exp $ */
+/* $OpenBSD: tls_key_share.c,v 1.2 2022/01/06 18:23:56 jsing Exp $ */
 /*
  * Copyright (c) 2020 Joel Sing <jsing@openbsd.org>
  *
@@ -28,6 +28,9 @@ struct tls_key_share {
        int nid;
        uint16_t group_id;
 
+       DH *dhe;
+       DH *dhe_peer;
+
        EC_KEY *ecdhe;
        EC_KEY *ecdhe_peer;
 
@@ -36,14 +39,10 @@ struct tls_key_share {
        uint8_t *x25519_peer_public;
 };
 
-struct tls_key_share *
-tls_key_share_new(uint16_t group_id)
+static struct tls_key_share *
+tls_key_share_new_internal(int nid, uint16_t group_id)
 {
        struct tls_key_share *ks;
-       int nid;
-
-       if ((nid = tls1_ec_curve_id2nid(group_id)) == 0)
-               return NULL;
 
        if ((ks = calloc(1, sizeof(struct tls_key_share))) == NULL)
                return NULL;
@@ -55,14 +54,27 @@ tls_key_share_new(uint16_t group_id)
 }
 
 struct tls_key_share *
-tls_key_share_new_nid(int nid)
+tls_key_share_new(uint16_t group_id)
 {
-       uint16_t group_id;
+       int nid;
 
-       if ((group_id = tls1_ec_nid2curve_id(nid)) == 0)
+       if ((nid = tls1_ec_curve_id2nid(group_id)) == 0)
                return NULL;
 
-       return tls_key_share_new(group_id);
+       return tls_key_share_new_internal(nid, group_id);
+}
+
+struct tls_key_share *
+tls_key_share_new_nid(int nid)
+{
+       uint16_t group_id = 0;
+
+       if (nid != NID_dhKeyAgreement) {
+               if ((group_id = tls1_ec_nid2curve_id(nid)) == 0)
+                       return NULL;
+       }
+
+       return tls_key_share_new_internal(nid, group_id);
 }
 
 void
@@ -71,6 +83,9 @@ tls_key_share_free(struct tls_key_share *ks)
        if (ks == NULL)
                return;
 
+       DH_free(ks->dhe);
+       DH_free(ks->dhe_peer);
+
        EC_KEY_free(ks->ecdhe);
        EC_KEY_free(ks->ecdhe_peer);
 
@@ -87,20 +102,34 @@ tls_key_share_group(struct tls_key_share *ks)
        return ks->group_id;
 }
 
+int
+tls_key_share_nid(struct tls_key_share *ks)
+{
+       return ks->nid;
+}
+
 int
 tls_key_share_peer_pkey(struct tls_key_share *ks, EVP_PKEY *pkey)
 {
-       if (ks->nid == NID_X25519 && ks->x25519_peer_public != NULL) {
-               if (!ssl_kex_dummy_ecdhe_x25519(pkey))
-                       return 0;
-       } else if (ks->ecdhe_peer != NULL) {
-               if (!EVP_PKEY_set1_EC_KEY(pkey, ks->ecdhe_peer))
-                       return 0;
-       } else {
+       if (ks->nid == NID_dhKeyAgreement && ks->dhe_peer != NULL)
+               return EVP_PKEY_set1_DH(pkey, ks->dhe_peer);
+
+       if (ks->nid == NID_X25519 && ks->x25519_peer_public != NULL)
+               return ssl_kex_dummy_ecdhe_x25519(pkey);
+
+       if (ks->ecdhe_peer != NULL)
+               return EVP_PKEY_set1_EC_KEY(pkey, ks->ecdhe_peer);
+
+       return 0;
+}
+
+static int
+tls_key_share_generate_dhe(struct tls_key_share *ks)
+{
+       if (ks->dhe == NULL)
                return 0;
-       }
 
-       return 1;
+       return ssl_kex_generate_dhe(ks->dhe, ks->dhe);
 }
 
 static int
@@ -161,12 +190,24 @@ tls_key_share_generate_x25519(struct tls_key_share *ks)
 int
 tls_key_share_generate(struct tls_key_share *ks)
 {
+       if (ks->nid == NID_dhKeyAgreement)
+               return tls_key_share_generate_dhe(ks);
+
        if (ks->nid == NID_X25519)
                return tls_key_share_generate_x25519(ks);
 
        return tls_key_share_generate_ecdhe_ecp(ks);
 }
 
+static int
+tls_key_share_public_dhe(struct tls_key_share *ks, CBB *cbb)
+{
+       if (ks->dhe == NULL)
+               return 0;
+
+       return ssl_kex_public_dhe(ks->dhe, cbb);
+}
+
 static int
 tls_key_share_public_ecdhe_ecp(struct tls_key_share *ks, CBB *cbb)
 {
@@ -188,12 +229,52 @@ tls_key_share_public_x25519(struct tls_key_share *ks, CBB *cbb)
 int
 tls_key_share_public(struct tls_key_share *ks, CBB *cbb)
 {
+       if (ks->nid == NID_dhKeyAgreement)
+               return tls_key_share_public_dhe(ks, cbb);
+
        if (ks->nid == NID_X25519)
                return tls_key_share_public_x25519(ks, cbb);
 
        return tls_key_share_public_ecdhe_ecp(ks, cbb);
 }
 
+static int
+tls_key_share_peer_params_dhe(struct tls_key_share *ks, CBS *cbs,
+    int *invalid_params)
+{
+       if (ks->dhe != NULL || ks->dhe_peer != NULL)
+               return 0;
+
+       if ((ks->dhe_peer = DH_new()) == NULL)
+               return 0;
+       if (!ssl_kex_peer_params_dhe(ks->dhe_peer, cbs, invalid_params))
+               return 0;
+       if ((ks->dhe = DHparams_dup(ks->dhe_peer)) == NULL)
+               return 0;
+
+       return 1;
+}
+
+int
+tls_key_share_peer_params(struct tls_key_share *ks, CBS *cbs,
+    int *invalid_params)
+{
+       if (ks->nid != NID_dhKeyAgreement)
+               return 0;
+
+       return tls_key_share_peer_params_dhe(ks, cbs, invalid_params);
+}
+
+static int
+tls_key_share_peer_public_dhe(struct tls_key_share *ks, CBS *cbs,
+    int *invalid_key)
+{
+       if (ks->dhe_peer == NULL)
+               return 0;
+
+       return ssl_kex_peer_public_dhe(ks->dhe_peer, cbs, invalid_key);
+}
+
 static int
 tls_key_share_peer_public_ecdhe_ecp(struct tls_key_share *ks, CBS *cbs)
 {
@@ -234,21 +315,29 @@ tls_key_share_peer_public_x25519(struct tls_key_share *ks, CBS *cbs)
 }
 
 int
-tls_key_share_peer_public(struct tls_key_share *ks, uint16_t group,
-    CBS *cbs)
+tls_key_share_peer_public(struct tls_key_share *ks, CBS *cbs, int *invalid_key)
 {
-       if (ks->group_id != group)
-               return 0;
+       if (invalid_key != NULL)
+               *invalid_key = 0;
 
-       if (ks->nid == NID_X25519) {
-               if (!tls_key_share_peer_public_x25519(ks, cbs))
-                       return 0;
-       } else {
-               if (!tls_key_share_peer_public_ecdhe_ecp(ks, cbs))
-                       return 0;
-       }
+       if (ks->nid == NID_dhKeyAgreement)
+               return tls_key_share_peer_public_dhe(ks, cbs, invalid_key);
 
-       return 1;
+       if (ks->nid == NID_X25519)
+               return tls_key_share_peer_public_x25519(ks, cbs);
+
+       return tls_key_share_peer_public_ecdhe_ecp(ks, cbs);
+}
+
+static int
+tls_key_share_derive_dhe(struct tls_key_share *ks,
+    uint8_t **shared_key, size_t *shared_key_len)
+{
+       if (ks->dhe == NULL || ks->dhe_peer == NULL)
+               return 0;
+
+       return ssl_kex_derive_dhe(ks->dhe, ks->dhe_peer, shared_key,
+           shared_key_len);
 }
 
 static int
@@ -298,6 +387,10 @@ tls_key_share_derive(struct tls_key_share *ks, uint8_t **shared_key,
 
        *shared_key_len = 0;
 
+       if (ks->nid == NID_dhKeyAgreement)
+               return tls_key_share_derive_dhe(ks, shared_key,
+                   shared_key_len);
+
        if (ks->nid == NID_X25519)
                return tls_key_share_derive_x25519(ks, shared_key,
                    shared_key_len);