Rework the MD setting in the RSA ASN.1 method
authortb <tb@openbsd.org>
Thu, 26 Oct 2023 07:57:54 +0000 (07:57 +0000)
committertb <tb@openbsd.org>
Thu, 26 Oct 2023 07:57:54 +0000 (07:57 +0000)
This streamlines the code to use safer idioms, do proper error checking
and be slightly less convoluted. Sprinkle a few references to RFC 8017
and explain better what we are doing and why. Clarify ownership and use
more consistent style.

This removes the last internal use of X509_ALGOR_set_md().

ok jsing

lib/libcrypto/rsa/rsa_ameth.c

index ae38c20..43f52f7 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: rsa_ameth.c,v 1.33 2023/08/12 08:02:43 tb Exp $ */
+/* $OpenBSD: rsa_ameth.c,v 1.34 2023/10/26 07:57:54 tb Exp $ */
 /* Written by Dr Stephen N Henson (steve@openssl.org) for the OpenSSL
  * project 2006.
  */
@@ -72,6 +72,7 @@
 #include "cryptlib.h"
 #include "evp_local.h"
 #include "rsa_local.h"
+#include "x509_local.h"
 
 #ifndef OPENSSL_NO_CMS
 static int rsa_cms_sign(CMS_SignerInfo *si);
@@ -574,45 +575,82 @@ rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
        return 1;
 }
 
-/* Allocate and set algorithm ID from EVP_MD, defaults to SHA1. */
 static int
-rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md)
+rsa_md_to_algor(const EVP_MD *md, X509_ALGOR **out_alg)
 {
+       X509_ALGOR *alg = NULL;
+       int ret = 0;
+
+       X509_ALGOR_free(*out_alg);
+       *out_alg = NULL;
+
+       /* RFC 8017 - default hash is SHA-1 and hence omitted. */
        if (md == NULL || EVP_MD_type(md) == NID_sha1)
-               return 1;
-       *palg = X509_ALGOR_new();
-       if (*palg == NULL)
-               return 0;
-       X509_ALGOR_set_md(*palg, md);
-       return 1;
+               goto done;
+
+       if ((alg = X509_ALGOR_new()) == NULL)
+               goto err;
+       if (!X509_ALGOR_set_evp_md(alg, md))
+               goto err;
+
+ done:
+       *out_alg = alg;
+       alg = NULL;
+
+       ret = 1;
+
+ err:
+       X509_ALGOR_free(alg);
+
+       return ret;
 }
 
-/* Allocate and set MGF1 algorithm ID from EVP_MD. */
+/*
+ * RFC 8017, A.2.1 and A.2.3 - encode maskGenAlgorithm for RSAES-OAEP
+ * and RSASSA-PSS. The default is mgfSHA1 and hence omitted.
+ */
 static int
-rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md)
+rsa_mgf1md_to_maskGenAlgorithm(const EVP_MD *mgf1md, X509_ALGOR **out_alg)
 {
-       X509_ALGOR *algtmp = NULL;
-       ASN1_STRING *stmp = NULL;
+       X509_ALGOR *alg = NULL;
+       X509_ALGOR *inner_alg = NULL;
+       ASN1_STRING *astr = NULL;
+       ASN1_OBJECT *aobj;
+       int ret = 0;
+
+       X509_ALGOR_free(*out_alg);
+       *out_alg = NULL;
 
-       *palg = NULL;
        if (mgf1md == NULL || EVP_MD_type(mgf1md) == NID_sha1)
-               return 1;
-       /* need to embed algorithm ID inside another */
-       if (!rsa_md_to_algor(&algtmp, mgf1md))
+               goto done;
+
+       if ((inner_alg = X509_ALGOR_new()) == NULL)
                goto err;
-       if (ASN1_item_pack(algtmp, &X509_ALGOR_it, &stmp) == NULL)
-                goto err;
-       *palg = X509_ALGOR_new();
-       if (*palg == NULL)
+       if (!X509_ALGOR_set_evp_md(inner_alg, mgf1md))
+               goto err;
+       if ((astr = ASN1_item_pack(inner_alg, &X509_ALGOR_it, NULL)) == NULL)
+               goto err;
+
+       if ((alg = X509_ALGOR_new()) == NULL)
                goto err;
-       X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp);
-       stmp = NULL;
+       if ((aobj = OBJ_nid2obj(NID_mgf1)) == NULL)
+               goto err;
+       if (!X509_ALGOR_set0(alg, aobj, V_ASN1_SEQUENCE, astr))
+               goto err;
+       astr = NULL;
+
+ done:
+       *out_alg = alg;
+       alg = NULL;
+
+       ret = 1;
+
  err:
-       ASN1_STRING_free(stmp);
-       X509_ALGOR_free(algtmp);
-       if (*palg)
-               return 1;
-       return 0;
+       X509_ALGOR_free(alg);
+       X509_ALGOR_free(inner_alg);
+       ASN1_STRING_free(astr);
+
+       return ret;
 }
 
 /* Convert algorithm ID to EVP_MD, defaults to SHA1. */
@@ -662,28 +700,36 @@ rsa_ctx_to_pss(EVP_PKEY_CTX *pkctx)
 RSA_PSS_PARAMS *
 rsa_pss_params_create(const EVP_MD *sigmd, const EVP_MD *mgf1md, int saltlen)
 {
-       RSA_PSS_PARAMS *pss = RSA_PSS_PARAMS_new();
+       RSA_PSS_PARAMS *pss = NULL;
 
-       if (pss == NULL)
+       if (mgf1md == NULL)
+               mgf1md = sigmd;
+
+       if ((pss = RSA_PSS_PARAMS_new()) == NULL)
+               goto err;
+
+       if (!rsa_md_to_algor(sigmd, &pss->hashAlgorithm))
+               goto err;
+       if (!rsa_mgf1md_to_maskGenAlgorithm(mgf1md, &pss->maskGenAlgorithm))
+               goto err;
+
+       /* Translate mgf1md to X509_ALGOR in decoded form for internal use. */
+       if (!rsa_md_to_algor(mgf1md, &pss->maskHash))
                goto err;
-       if (saltlen != 20) {
-               pss->saltLength = ASN1_INTEGER_new();
-               if (pss->saltLength == NULL)
+
+       /* RFC 8017, A.2.3 - default saltLength is SHA_DIGEST_LENGTH. */
+       if (saltlen != SHA_DIGEST_LENGTH) {
+               if ((pss->saltLength = ASN1_INTEGER_new()) == NULL)
                        goto err;
                if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
                        goto err;
        }
-       if (!rsa_md_to_algor(&pss->hashAlgorithm, sigmd))
-               goto err;
-       if (mgf1md == NULL)
-               mgf1md = sigmd;
-       if (!rsa_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md))
-               goto err;
-       if (!rsa_md_to_algor(&pss->maskHash, mgf1md))
-               goto err;
+
        return pss;
+
  err:
        RSA_PSS_PARAMS_free(pss);
+
        return NULL;
 }
 
@@ -1035,13 +1081,17 @@ rsa_cms_encrypt(CMS_RecipientInfo *ri)
        labellen = EVP_PKEY_CTX_get0_rsa_oaep_label(pkctx, &label);
        if (labellen < 0)
                goto err;
-       oaep = RSA_OAEP_PARAMS_new();
-       if (oaep == NULL)
+
+       if ((oaep = RSA_OAEP_PARAMS_new()) == NULL)
                goto err;
-       if (!rsa_md_to_algor(&oaep->hashFunc, md))
+
+       if (!rsa_md_to_algor(md, &oaep->hashFunc))
                goto err;
-       if (!rsa_md_to_mgf1(&oaep->maskGenFunc, mgf1md))
+       if (!rsa_mgf1md_to_maskGenAlgorithm(mgf1md, &oaep->maskGenFunc))
                goto err;
+
+       /* XXX - why do we not set oaep->maskHash here? */
+
        if (labellen > 0) {
                ASN1_OCTET_STRING *los;
                oaep->pSourceFunc = X509_ALGOR_new();