Convert ssl3_get_server_key_exchange() to CBS.
authorjsing <jsing@openbsd.org>
Thu, 16 Aug 2018 17:39:50 +0000 (17:39 +0000)
committerjsing <jsing@openbsd.org>
Thu, 16 Aug 2018 17:39:50 +0000 (17:39 +0000)
ok inoguchi@ tb@

lib/libssl/ssl_clnt.c

index 83b2c1b..c53fbda 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: ssl_clnt.c,v 1.29 2018/08/14 16:31:02 jsing Exp $ */
+/* $OpenBSD: ssl_clnt.c,v 1.30 2018/08/16 17:39:50 jsing Exp $ */
 /* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)
  * All rights reserved.
  *
@@ -1189,9 +1189,9 @@ err:
 }
 
 static int
-ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
+ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
 {
-       CBS cbs, dhp, dhg, dhpk;
+       CBS dhp, dhg, dhpk;
        BN_CTX *bn_ctx = NULL;
        SESS_CERT *sc = NULL;
        DH *dh = NULL;
@@ -1201,31 +1201,26 @@ ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
        alg_a = S3I(s)->hs.new_cipher->algorithm_auth;
        sc = SSI(s)->sess_cert;
 
-       if (*nn < 0)
-               goto err;
-
-       CBS_init(&cbs, *pp, *nn);
-
        if ((dh = DH_new()) == NULL) {
                SSLerror(s, ERR_R_DH_LIB);
                goto err;
        }
 
-       if (!CBS_get_u16_length_prefixed(&cbs, &dhp))
+       if (!CBS_get_u16_length_prefixed(cbs, &dhp))
                goto truncated;
        if ((dh->p = BN_bin2bn(CBS_data(&dhp), CBS_len(&dhp), NULL)) == NULL) {
                SSLerror(s, ERR_R_BN_LIB);
                goto err;
        }
 
-       if (!CBS_get_u16_length_prefixed(&cbs, &dhg))
+       if (!CBS_get_u16_length_prefixed(cbs, &dhg))
                goto truncated;
        if ((dh->g = BN_bin2bn(CBS_data(&dhg), CBS_len(&dhg), NULL)) == NULL) {
                SSLerror(s, ERR_R_BN_LIB);
                goto err;
        }
 
-       if (!CBS_get_u16_length_prefixed(&cbs, &dhpk))
+       if (!CBS_get_u16_length_prefixed(cbs, &dhpk))
                goto truncated;
        if ((dh->pub_key = BN_bin2bn(CBS_data(&dhpk), CBS_len(&dhpk),
            NULL)) == NULL) {
@@ -1250,9 +1245,6 @@ ssl3_get_server_kex_dhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
 
        sc->peer_dh_tmp = dh;
 
-       *nn = CBS_len(&cbs);
-       *pp = (unsigned char *)CBS_data(&cbs);
-
        return (1);
 
  truncated:
@@ -1353,9 +1345,9 @@ ssl3_get_server_kex_ecdhe_ecx(SSL *s, SESS_CERT *sc, int nid, CBS *public)
 }
 
 static int
-ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
+ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, CBS *cbs)
 {
-       CBS cbs, public;
+       CBS public;
        uint8_t curve_type;
        uint16_t curve_id;
        SESS_CERT *sc;
@@ -1366,15 +1358,10 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
        alg_a = S3I(s)->hs.new_cipher->algorithm_auth;
        sc = SSI(s)->sess_cert;
 
-       if (*nn < 0)
-               goto err;
-
-       CBS_init(&cbs, *pp, *nn);
-
        /* Only named curves are supported. */
-       if (!CBS_get_u8(&cbs, &curve_type) ||
+       if (!CBS_get_u8(cbs, &curve_type) ||
            curve_type != NAMED_CURVE_TYPE ||
-           !CBS_get_u16(&cbs, &curve_id)) {
+           !CBS_get_u16(cbs, &curve_id)) {
                al = SSL_AD_DECODE_ERROR;
                SSLerror(s, SSL_R_LENGTH_TOO_SHORT);
                goto f_err;
@@ -1396,7 +1383,7 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
                goto f_err;
        }
 
-       if (!CBS_get_u8_length_prefixed(&cbs, &public))
+       if (!CBS_get_u8_length_prefixed(cbs, &public))
                goto truncated;
 
        if (nid == NID_X25519) {
@@ -1420,9 +1407,6 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
                /* XXX - Anonymous ECDH, so no certificate or pkey. */
                *pkey = NULL;
 
-       *nn = CBS_len(&cbs);
-       *pp = (unsigned char *)CBS_data(&cbs);
-
        return (1);
 
  truncated:
@@ -1439,12 +1423,17 @@ ssl3_get_server_kex_ecdhe(SSL *s, EVP_PKEY **pkey, unsigned char **pp, long *nn)
 int
 ssl3_get_server_key_exchange(SSL *s)
 {
-       EVP_MD_CTX       md_ctx;
-       unsigned char   *param, *p;
-       int              al, i, j, param_len, ok;
-       long             n, alg_k, alg_a;
-       EVP_PKEY        *pkey = NULL;
-       const            EVP_MD *md = NULL;
+       CBS cbs, signature;
+       const EVP_MD *md = NULL;
+       EVP_PKEY *pkey = NULL;
+       EVP_MD_CTX md_ctx;
+       const unsigned char *param;
+       uint8_t hash_id, sig_id;
+       long n, alg_k, alg_a;
+       int al, ok, sigalg;
+       size_t param_len;
+
+       EVP_MD_CTX_init(&md_ctx);
 
        alg_k = S3I(s)->hs.new_cipher->algorithm_mkey;
        alg_a = S3I(s)->hs.new_cipher->algorithm_auth;
@@ -1458,7 +1447,10 @@ ssl3_get_server_key_exchange(SSL *s)
        if (!ok)
                return ((int)n);
 
-       EVP_MD_CTX_init(&md_ctx);
+       if (n < 0)
+               goto err;
+
+       CBS_init(&cbs, s->internal->init_msg, n);
 
        if (S3I(s)->tmp.message_type != SSL3_MT_SERVER_KEY_EXCHANGE) {
                /*
@@ -1491,14 +1483,14 @@ ssl3_get_server_key_exchange(SSL *s)
                        goto err;
        }
 
-       param = p = (unsigned char *)s->internal->init_msg;
-       param_len = n;
+       param = CBS_data(&cbs);
+       param_len = CBS_len(&cbs);
 
        if (alg_k & SSL_kDHE) {
-               if (ssl3_get_server_kex_dhe(s, &pkey, &p, &n) != 1)
+               if (ssl3_get_server_kex_dhe(s, &pkey, &cbs) != 1)
                        goto err;
        } else if (alg_k & SSL_kECDHE) {
-               if (ssl3_get_server_kex_ecdhe(s, &pkey, &p, &n) != 1)
+               if (ssl3_get_server_kex_ecdhe(s, &pkey, &cbs) != 1)
                        goto err;
        } else if (alg_k != 0) {
                al = SSL_AD_UNEXPECTED_MESSAGE;
@@ -1506,47 +1498,42 @@ ssl3_get_server_key_exchange(SSL *s)
                        goto f_err;
        }
 
-       param_len = param_len - n;
+       param_len -= CBS_len(&cbs);
 
        /* if it was signed, check the signature */
        if (pkey != NULL) {
                if (SSL_USE_SIGALGS(s)) {
-                       int sigalg = tls12_get_sigid(pkey);
-                       if (sigalg == -1) {
+                       if (!CBS_get_u8(&cbs, &hash_id))
+                               goto truncated;
+                       if (!CBS_get_u8(&cbs, &sig_id))
+                               goto truncated;
+
+                       if ((md = tls12_get_hash(hash_id)) == NULL) {
+                               SSLerror(s, SSL_R_UNKNOWN_DIGEST);
+                               al = SSL_AD_DECODE_ERROR;
+                               goto f_err;
+                       }
+
+                       /* Check key type is consistent with signature. */
+                       if ((sigalg = tls12_get_sigid(pkey)) == -1) {
                                /* Should never happen */
                                SSLerror(s, ERR_R_INTERNAL_ERROR);
                                goto err;
                        }
-                       /* Check key type is consistent with signature. */
-                       if (2 > n)
-                               goto truncated;
-                       if (sigalg != (int)p[1]) {
+                       if (sigalg != sig_id) {
                                SSLerror(s, SSL_R_WRONG_SIGNATURE_TYPE);
                                al = SSL_AD_DECODE_ERROR;
                                goto f_err;
                        }
-                       md = tls12_get_hash(p[0]);
-                       if (md == NULL) {
-                               SSLerror(s, SSL_R_UNKNOWN_DIGEST);
-                               al = SSL_AD_DECODE_ERROR;
-                               goto f_err;
-                       }
-                       p += 2;
-                       n -= 2;
                } else if (pkey->type == EVP_PKEY_RSA) {
                        md = EVP_md5_sha1();
                } else {
                        md = EVP_sha1();
                }
 
-               if (2 > n)
+               if (!CBS_get_u16_length_prefixed(&cbs, &signature))
                        goto truncated;
-               n2s(p, i);
-               n -= 2;
-               j = EVP_PKEY_size(pkey);
-
-               if (i != n || n > j) {
-                       /* wrong packet length */
+               if (CBS_len(&signature) > EVP_PKEY_size(pkey)) {
                        al = SSL_AD_DECODE_ERROR;
                        SSLerror(s, SSL_R_WRONG_SIGNATURE_LENGTH);
                        goto f_err;
@@ -1562,8 +1549,8 @@ ssl3_get_server_key_exchange(SSL *s)
                        goto err;
                if (!EVP_VerifyUpdate(&md_ctx, param, param_len))
                        goto err;
-               if (EVP_VerifyFinal(&md_ctx, p, (int)n, pkey) <= 0) {
-                       /* bad signature */
+               if (EVP_VerifyFinal(&md_ctx, CBS_data(&signature),
+                   CBS_len(&signature), pkey) <= 0) {
                        al = SSL_AD_DECRYPT_ERROR;
                        SSLerror(s, SSL_R_BAD_SIGNATURE);
                        goto f_err;
@@ -1574,12 +1561,12 @@ ssl3_get_server_key_exchange(SSL *s)
                        SSLerror(s, ERR_R_INTERNAL_ERROR);
                        goto err;
                }
-               /* still data left over */
-               if (n != 0) {
-                       al = SSL_AD_DECODE_ERROR;
-                       SSLerror(s, SSL_R_EXTRA_DATA_IN_MESSAGE);
-                       goto f_err;
-               }
+       }
+
+       if (CBS_len(&cbs) != 0) {
+               al = SSL_AD_DECODE_ERROR;
+               SSLerror(s, SSL_R_EXTRA_DATA_IN_MESSAGE);
+               goto f_err;
        }
 
        EVP_PKEY_free(pkey);
@@ -1588,7 +1575,6 @@ ssl3_get_server_key_exchange(SSL *s)
        return (1);
 
  truncated:
-       /* wrong packet length */
        al = SSL_AD_DECODE_ERROR;
        SSLerror(s, SSL_R_BAD_PACKET_LENGTH);