refactor and simplify sshkey_read()
authordjm <djm@openbsd.org>
Fri, 28 Oct 2022 00:38:58 +0000 (00:38 +0000)
committerdjm <djm@openbsd.org>
Fri, 28 Oct 2022 00:38:58 +0000 (00:38 +0000)
feedback/ok markus@

usr.bin/ssh/sshkey.c

index e383c62..ea52028 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: sshkey.c,v 1.125 2022/10/28 00:37:24 djm Exp $ */
+/* $OpenBSD: sshkey.c,v 1.126 2022/10/28 00:38:58 djm Exp $ */
 /*
  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
  * Copyright (c) 2008 Alexander von Gernler.  All rights reserved.
@@ -171,12 +171,20 @@ sshkey_impl_from_type_nid(int type, int nid)
        return NULL;
 }
 
+static const struct sshkey_impl *
+sshkey_impl_from_key(const struct sshkey *k)
+{
+       if (k == NULL)
+               return NULL;
+       return sshkey_impl_from_type_nid(k->type, k->ecdsa_nid);
+}
+
 const char *
 sshkey_type(const struct sshkey *k)
 {
        const struct sshkey_impl *impl;
 
-       if ((impl = sshkey_impl_from_type(k->type)) == NULL)
+       if ((impl = sshkey_impl_from_key(k)) == NULL)
                return "unknown";
        return impl->shortname;
 }
@@ -355,7 +363,7 @@ sshkey_size(const struct sshkey *k)
 {
        const struct sshkey_impl *impl;
 
-       if ((impl = sshkey_impl_from_type_nid(k->type, k->ecdsa_nid)) == NULL)
+       if ((impl = sshkey_impl_from_key(k)) == NULL)
                return 0;
        if (impl->funcs->size != NULL)
                return impl->funcs->size(k);
@@ -578,8 +586,8 @@ sshkey_sk_cleanup(struct sshkey *k)
        k->sk_key_handle = k->sk_reserved = NULL;
 }
 
-void
-sshkey_free(struct sshkey *k)
+static void
+sshkey_free_contents(struct sshkey *k)
 {
        const struct sshkey_impl *impl;
 
@@ -592,6 +600,12 @@ sshkey_free(struct sshkey *k)
                cert_free(k->cert);
        freezero(k->shielded_private, k->shielded_len);
        freezero(k->shield_prekey, k->shield_prekey_len);
+}
+
+void
+sshkey_free(struct sshkey *k)
+{
+       sshkey_free_contents(k);
        freezero(k, sizeof(*k));
 }
 
@@ -1105,29 +1119,8 @@ sshkey_read(struct sshkey *ret, char **cpp)
 
        if (ret == NULL)
                return SSH_ERR_INVALID_ARGUMENT;
-
-       switch (ret->type) {
-       case KEY_UNSPEC:
-       case KEY_RSA:
-       case KEY_DSA:
-       case KEY_ECDSA:
-       case KEY_ECDSA_SK:
-       case KEY_ED25519:
-       case KEY_ED25519_SK:
-       case KEY_DSA_CERT:
-       case KEY_ECDSA_CERT:
-       case KEY_ECDSA_SK_CERT:
-       case KEY_RSA_CERT:
-       case KEY_ED25519_CERT:
-       case KEY_ED25519_SK_CERT:
-#ifdef WITH_XMSS
-       case KEY_XMSS:
-       case KEY_XMSS_CERT:
-#endif /* WITH_XMSS */
-               break; /* ok */
-       default:
+       if (ret->type != KEY_UNSPEC && sshkey_impl_from_type(ret->type) == NULL)
                return SSH_ERR_INVALID_ARGUMENT;
-       }
 
        /* Decode type */
        cp = *cpp;
@@ -1180,98 +1173,9 @@ sshkey_read(struct sshkey *ret, char **cpp)
        }
 
        /* Fill in ret from parsed key */
-       ret->type = type;
-       if (sshkey_is_cert(ret)) {
-               if (!sshkey_is_cert(k)) {
-                       sshkey_free(k);
-                       return SSH_ERR_EXPECTED_CERT;
-               }
-               if (ret->cert != NULL)
-                       cert_free(ret->cert);
-               ret->cert = k->cert;
-               k->cert = NULL;
-       }
-       switch (sshkey_type_plain(ret->type)) {
-#ifdef WITH_OPENSSL
-       case KEY_RSA:
-               RSA_free(ret->rsa);
-               ret->rsa = k->rsa;
-               k->rsa = NULL;
-#ifdef DEBUG_PK
-               RSA_print_fp(stderr, ret->rsa, 8);
-#endif
-               break;
-       case KEY_DSA:
-               DSA_free(ret->dsa);
-               ret->dsa = k->dsa;
-               k->dsa = NULL;
-#ifdef DEBUG_PK
-               DSA_print_fp(stderr, ret->dsa, 8);
-#endif
-               break;
-       case KEY_ECDSA:
-               EC_KEY_free(ret->ecdsa);
-               ret->ecdsa = k->ecdsa;
-               ret->ecdsa_nid = k->ecdsa_nid;
-               k->ecdsa = NULL;
-               k->ecdsa_nid = -1;
-#ifdef DEBUG_PK
-               sshkey_dump_ec_key(ret->ecdsa);
-#endif
-               break;
-       case KEY_ECDSA_SK:
-               EC_KEY_free(ret->ecdsa);
-               ret->ecdsa = k->ecdsa;
-               ret->ecdsa_nid = k->ecdsa_nid;
-               ret->sk_application = k->sk_application;
-               k->ecdsa = NULL;
-               k->ecdsa_nid = -1;
-               k->sk_application = NULL;
-#ifdef DEBUG_PK
-               sshkey_dump_ec_key(ret->ecdsa);
-               fprintf(stderr, "App: %s\n", ret->sk_application);
-#endif
-               break;
-#endif /* WITH_OPENSSL */
-       case KEY_ED25519:
-               freezero(ret->ed25519_pk, ED25519_PK_SZ);
-               ret->ed25519_pk = k->ed25519_pk;
-               k->ed25519_pk = NULL;
-#ifdef DEBUG_PK
-               /* XXX */
-#endif
-               break;
-       case KEY_ED25519_SK:
-               freezero(ret->ed25519_pk, ED25519_PK_SZ);
-               ret->ed25519_pk = k->ed25519_pk;
-               ret->sk_application = k->sk_application;
-               k->ed25519_pk = NULL;
-               k->sk_application = NULL;
-               break;
-#ifdef WITH_XMSS
-       case KEY_XMSS:
-               free(ret->xmss_pk);
-               ret->xmss_pk = k->xmss_pk;
-               k->xmss_pk = NULL;
-               free(ret->xmss_state);
-               ret->xmss_state = k->xmss_state;
-               k->xmss_state = NULL;
-               free(ret->xmss_name);
-               ret->xmss_name = k->xmss_name;
-               k->xmss_name = NULL;
-               free(ret->xmss_filename);
-               ret->xmss_filename = k->xmss_filename;
-               k->xmss_filename = NULL;
-#ifdef DEBUG_PK
-               /* XXX */
-#endif
-               break;
-#endif /* WITH_XMSS */
-       default:
-               sshkey_free(k);
-               return SSH_ERR_INTERNAL_ERROR;
-       }
-       sshkey_free(k);
+       sshkey_free_contents(ret);
+       *ret = *k;
+       freezero(k, sizeof(*k));
 
        /* success */
        *cpp = cp;