use HMAC-MD5, HMAC-SHA1 and AES Key Wrap sys/crypto/
authordamien <damien@openbsd.org>
Tue, 12 Aug 2008 15:59:40 +0000 (15:59 +0000)
committerdamien <damien@openbsd.org>
Tue, 12 Aug 2008 15:59:40 +0000 (15:59 +0000)
sys/net80211/ieee80211_crypto.c

index 4cdd07c..801323d 100644 (file)
@@ -1,4 +1,4 @@
-/*     $OpenBSD: ieee80211_crypto.c,v 1.43 2008/07/21 19:27:26 damien Exp $    */
+/*     $OpenBSD: ieee80211_crypto.c,v 1.44 2008/08/12 15:59:40 damien Exp $    */
 
 /*-
  * Copyright (c) 2008 Damien Bergamini <damien.bergamini@free.fr>
 #include <crypto/arc4.h>
 #include <crypto/md5.h>
 #include <crypto/sha1.h>
+#include <crypto/sha2.h>
+#include <crypto/hmac.h>
 #include <crypto/rijndael.h>
+#include <crypto/key_wrap.h>
 
-/* similar to iovec except that it accepts const pointers */
-struct vector {
-       const void      *base;
-       size_t          len;
-};
-
-void   ieee80211_prf(const u_int8_t *, size_t, struct vector *, int,
-           u_int8_t *, size_t);
+void   ieee80211_prf(const u_int8_t *, size_t, const u_int8_t *, size_t,
+           const u_int8_t *, size_t, u_int8_t *, size_t);
 void   ieee80211_derive_pmkid(const u_int8_t *, size_t, const u_int8_t *,
            const u_int8_t *, u_int8_t *);
 
@@ -227,189 +224,33 @@ ieee80211_decrypt(struct ieee80211com *ic, struct mbuf *m0,
        return m0;
 }
 
-/*
- * AES Key Wrap (see RFC 3394).
- */
-static const u_int8_t aes_key_wrap_iv[8] =
-       { 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6 };
-
-static void
-aes_key_wrap(const u_int8_t *kek, size_t kek_len, const u_int8_t *pt,
-    size_t len, u_int8_t *ct)
-{
-       rijndael_ctx ctx;
-       u_int8_t *a, *r, ar[16];
-       u_int64_t t, b[2];
-       size_t i;
-       int j;
-
-       /* allow ciphertext and plaintext to overlap (ct == pt) */
-       ovbcopy(pt, ct + 8, len * 8);
-       a = ct;
-       memcpy(a, aes_key_wrap_iv, 8);  /* default IV */
-
-       rijndael_set_key_enc_only(&ctx, kek, kek_len * 8);
-
-       for (j = 0, t = 1; j < 6; j++) {
-               r = ct + 8;
-               for (i = 0; i < len; i++, t++) {
-                       memcpy(ar, a, 8);
-                       memcpy(ar + 8, r, 8);
-                       rijndael_encrypt(&ctx, ar, (u_int8_t *)b);
-                       b[0] ^= htobe64(t);
-                       memcpy(a, &b[0], 8);
-                       memcpy(r, &b[1], 8);
-
-                       r += 8;
-               }
-       }
-}
-
-static int
-aes_key_unwrap(const u_int8_t *kek, size_t kek_len, const u_int8_t *ct,
-    u_int8_t *pt, size_t len)
-{
-       rijndael_ctx ctx;
-       u_int8_t a[8], *r, b[16];
-       u_int64_t t, ar[2];
-       size_t i;
-       int j;
-
-       memcpy(a, ct, 8);
-       /* allow ciphertext and plaintext to overlap (ct == pt) */
-       ovbcopy(ct + 8, pt, len * 8);
-
-       rijndael_set_key(&ctx, kek, kek_len * 8);
-
-       for (j = 0, t = 6 * len; j < 6; j++) {
-               r = pt + (len - 1) * 8;
-               for (i = 0; i < len; i++, t--) {
-                       memcpy(&ar[0], a, 8);
-                       ar[0] ^= htobe64(t);
-                       memcpy(&ar[1], r, 8);
-                       rijndael_decrypt(&ctx, (u_int8_t *)ar, b);
-                       memcpy(a, b, 8);
-                       memcpy(r, b + 8, 8);
-
-                       r -= 8;
-               }
-       }
-       return memcmp(a, aes_key_wrap_iv, 8) != 0;
-}
-
-/*
- * HMAC-MD5 (see RFC 2104).
- */
-static void
-hmac_md5(const struct vector *vec, int vcnt, const u_int8_t *key,
-    size_t key_len, u_int8_t digest[MD5_DIGEST_LENGTH])
-{
-       MD5_CTX ctx;
-       u_int8_t k_pad[MD5_BLOCK_LENGTH];
-       u_int8_t tk[MD5_DIGEST_LENGTH];
-       int i;
-
-       if (key_len > MD5_BLOCK_LENGTH) {
-               MD5Init(&ctx);
-               MD5Update(&ctx, key, key_len);
-               MD5Final(tk, &ctx);
-
-               key = tk;
-               key_len = MD5_DIGEST_LENGTH;
-       }
-
-       bzero(k_pad, sizeof k_pad);
-       bcopy(key, k_pad, key_len);
-       for (i = 0; i < MD5_BLOCK_LENGTH; i++)
-               k_pad[i] ^= 0x36;
-
-       MD5Init(&ctx);
-       MD5Update(&ctx, k_pad, MD5_BLOCK_LENGTH);
-       for (i = 0; i < vcnt; i++)
-               MD5Update(&ctx, vec[i].base, vec[i].len);
-       MD5Final(digest, &ctx);
-
-       bzero(k_pad, sizeof k_pad);
-       bcopy(key, k_pad, key_len);
-       for (i = 0; i < MD5_BLOCK_LENGTH; i++)
-               k_pad[i] ^= 0x5c;
-
-       MD5Init(&ctx);
-       MD5Update(&ctx, k_pad, MD5_BLOCK_LENGTH);
-       MD5Update(&ctx, digest, MD5_DIGEST_LENGTH);
-       MD5Final(digest, &ctx);
-}
-
-/*
- * HMAC-SHA1 (see RFC 2104).
- */
-static void
-hmac_sha1(const struct vector *vec, int vcnt, const u_int8_t *key,
-    size_t key_len, u_int8_t digest[SHA1_DIGEST_LENGTH])
-{
-       SHA1_CTX ctx;
-       u_int8_t k_pad[SHA1_BLOCK_LENGTH];
-       u_int8_t tk[SHA1_DIGEST_LENGTH];
-       int i;
-
-       if (key_len > SHA1_BLOCK_LENGTH) {
-               SHA1Init(&ctx);
-               SHA1Update(&ctx, key, key_len);
-               SHA1Final(tk, &ctx);
-
-               key = tk;
-               key_len = SHA1_DIGEST_LENGTH;
-       }
-
-       bzero(k_pad, sizeof k_pad);
-       bcopy(key, k_pad, key_len);
-       for (i = 0; i < SHA1_BLOCK_LENGTH; i++)
-               k_pad[i] ^= 0x36;
-
-       SHA1Init(&ctx);
-       SHA1Update(&ctx, k_pad, SHA1_BLOCK_LENGTH);
-       for (i = 0; i < vcnt; i++)
-               SHA1Update(&ctx, vec[i].base, vec[i].len);
-       SHA1Final(digest, &ctx);
-
-       bzero(k_pad, sizeof k_pad);
-       bcopy(key, k_pad, key_len);
-       for (i = 0; i < SHA1_BLOCK_LENGTH; i++)
-               k_pad[i] ^= 0x5c;
-
-       SHA1Init(&ctx);
-       SHA1Update(&ctx, k_pad, SHA1_BLOCK_LENGTH);
-       SHA1Update(&ctx, digest, SHA1_DIGEST_LENGTH);
-       SHA1Final(digest, &ctx);
-}
-
 /*
  * SHA1-based Pseudo-Random Function (see 8.5.1.1).
  */
 void
-ieee80211_prf(const u_int8_t *key, size_t key_len, struct vector *vec,
-    int vcnt, u_int8_t *output, size_t len)
+ieee80211_prf(const u_int8_t *key, size_t key_len, const u_int8_t *label,
+    size_t label_len, const u_int8_t *context, size_t context_len,
+    u_int8_t *output, size_t len)
 {
-       u_int8_t hash[SHA1_DIGEST_LENGTH];
-       u_int8_t count = 0;
-
-       /* single octet count, starts at 0 */
-       vec[vcnt].base = &count;
-       vec[vcnt].len  = 1;
-       vcnt++;
-
-       while (len >= SHA1_DIGEST_LENGTH) {
-               hmac_sha1(vec, vcnt, key, key_len, output);
-               count++;
-
+       HMAC_SHA1_CTX ctx;
+       u_int8_t digest[SHA1_DIGEST_LENGTH];
+       u_int8_t count;
+
+       for (count = 0; len != 0; count++) {
+               HMAC_SHA1_Init(&ctx, key, key_len);
+               HMAC_SHA1_Update(&ctx, label, label_len);
+               HMAC_SHA1_Update(&ctx, context, context_len);
+               HMAC_SHA1_Update(&ctx, &count, 1);
+               if (len < SHA1_DIGEST_LENGTH) {
+                       HMAC_SHA1_Final(digest, &ctx);
+                       /* truncate HMAC-SHA1 to len bytes */
+                       memcpy(output, digest, len);
+                       break;
+               }
+               HMAC_SHA1_Final(output, &ctx);
                output += SHA1_DIGEST_LENGTH;
                len -= SHA1_DIGEST_LENGTH;
        }
-       if (len > 0) {
-               hmac_sha1(vec, vcnt, key, key_len, hash);
-               /* truncate HMAC-SHA1 to len bytes */
-               memcpy(output, hash, len);
-       }
 }
 
 /*
@@ -420,29 +261,21 @@ ieee80211_derive_ptk(const u_int8_t *pmk, size_t pmk_len, const u_int8_t *aa,
     const u_int8_t *spa, const u_int8_t *anonce, const u_int8_t *snonce,
     u_int8_t *ptk, size_t ptk_len)
 {
-       struct vector vec[6];   /* +1 for PRF */
+       u_int8_t buf[2 * IEEE80211_ADDR_LEN + 2 * EAPOL_KEY_NONCE_LEN];
        int ret;
 
-       vec[0].base = "Pairwise key expansion";
-       vec[0].len  = 23;       /* include trailing '\0' */
-
+       /* Min(AA,SPA) || Max(AA,SPA) */
        ret = memcmp(aa, spa, IEEE80211_ADDR_LEN) < 0;
-       /* Min(AA,SPA) */
-       vec[1].base = ret ? aa : spa;
-       vec[1].len  = IEEE80211_ADDR_LEN;
-       /* Max(AA,SPA) */
-       vec[2].base = ret ? spa : aa;
-       vec[2].len  = IEEE80211_ADDR_LEN;
+       memcpy(&buf[ 0], ret ? aa : spa, IEEE80211_ADDR_LEN);
+       memcpy(&buf[ 6], ret ? spa : aa, IEEE80211_ADDR_LEN);
 
+       /* Min(ANonce,SNonce) || Max(ANonce,SNonce) */
        ret = memcmp(anonce, snonce, EAPOL_KEY_NONCE_LEN) < 0;
-       /* Min(ANonce,SNonce) */
-       vec[3].base = ret ? anonce : snonce;
-       vec[3].len  = EAPOL_KEY_NONCE_LEN;
-       /* Max(ANonce,SNonce) */
-       vec[4].base = ret ? snonce : anonce;
-       vec[4].len  = EAPOL_KEY_NONCE_LEN;
-
-       ieee80211_prf(pmk, pmk_len, vec, 5, ptk, ptk_len);
+       memcpy(&buf[12], ret ? anonce : snonce, EAPOL_KEY_NONCE_LEN);
+       memcpy(&buf[44], ret ? snonce : anonce, EAPOL_KEY_NONCE_LEN);
+
+       ieee80211_prf(pmk, pmk_len, "Pairwise key expansion", 23,
+           buf, sizeof buf, ptk, ptk_len);
 }
 
 /*
@@ -452,21 +285,23 @@ void
 ieee80211_derive_pmkid(const u_int8_t *pmk, size_t pmk_len, const u_int8_t *aa,
     const u_int8_t *spa, u_int8_t *pmkid)
 {
-       struct vector vec[3];
-       u_int8_t hash[SHA1_DIGEST_LENGTH];
-
-       vec[0].base = "PMK Name";
-       vec[0].len  = 8;        /* does *not* include trailing '\0' */
-       vec[1].base = aa;
-       vec[1].len  = IEEE80211_ADDR_LEN;
-       vec[2].base = spa;
-       vec[2].len  = IEEE80211_ADDR_LEN;
-
-       hmac_sha1(vec, 3, pmk, pmk_len, hash);
+       HMAC_SHA1_CTX ctx;
+       u_int8_t digest[SHA1_DIGEST_LENGTH];
+
+       HMAC_SHA1_Init(&ctx, pmk, pmk_len);
+       HMAC_SHA1_Update(&ctx, "PMK Name", 8);
+       HMAC_SHA1_Update(&ctx, aa, IEEE80211_ADDR_LEN);
+       HMAC_SHA1_Update(&ctx, spa, IEEE80211_ADDR_LEN);
+       HMAC_SHA1_Final(digest, &ctx);
        /* use the first 128 bits of the HMAC-SHA1 */
-       memcpy(pmkid, hash, IEEE80211_PMKID_LEN);
+       memcpy(pmkid, digest, IEEE80211_PMKID_LEN);
 }
 
+typedef union _ANY_CTX {
+       HMAC_MD5_CTX    md5;
+       HMAC_SHA1_CTX   sha1;
+} ANY_CTX;
+
 /*
  * Compute the Key MIC field of an EAPOL-Key frame using the specified Key
  * Confirmation Key (KCK).  The hash function can be either HMAC-MD5 or
@@ -475,20 +310,24 @@ ieee80211_derive_pmkid(const u_int8_t *pmk, size_t pmk_len, const u_int8_t *aa,
 void
 ieee80211_eapol_key_mic(struct ieee80211_eapol_key *key, const u_int8_t *kck)
 {
-       u_int8_t hash[SHA1_DIGEST_LENGTH];
-       struct vector vec;
+       u_int8_t digest[SHA1_DIGEST_LENGTH];
+       ANY_CTX ctx;    /* XXX off stack? */
+       u_int len;
 
-       vec.base = key;
-       vec.len  = BE_READ_2(key->len) + 4;
+       len = BE_READ_2(key->len) + 4;
 
        switch (BE_READ_2(key->info) & EAPOL_KEY_VERSION_MASK) {
        case EAPOL_KEY_DESC_V1:
-               hmac_md5(&vec, 1, kck, 16, key->mic);
+               HMAC_MD5_Init(&ctx.md5, kck, 16);
+               HMAC_MD5_Update(&ctx.md5, (u_int8_t *)key, len);
+               HMAC_MD5_Final(key->mic, &ctx.md5);
                break;
        case EAPOL_KEY_DESC_V2:
-               hmac_sha1(&vec, 1, kck, 16, hash);
+               HMAC_SHA1_Init(&ctx.sha1, kck, 16);
+               HMAC_SHA1_Update(&ctx.sha1, (u_int8_t *)key, len);
+               HMAC_SHA1_Final(digest, &ctx.sha1);
                /* truncate HMAC-SHA1 to its 128 MSBs */
-               memcpy(key->mic, hash, EAPOL_KEY_MIC_LEN);
+               memcpy(key->mic, digest, EAPOL_KEY_MIC_LEN);
                break;
        }
 }
@@ -519,7 +358,10 @@ void
 ieee80211_eapol_key_encrypt(struct ieee80211com *ic,
     struct ieee80211_eapol_key *key, const u_int8_t *kek)
 {
-       struct rc4_ctx ctx;
+       union {
+               struct rc4_ctx rc4;
+               aes_key_wrap_ctx aes;
+       } ctx;  /* XXX off stack? */
        u_int8_t keybuf[EAPOL_KEY_IV_LEN + 16];
        u_int16_t len, info;
        u_int8_t *data;
@@ -540,10 +382,10 @@ ieee80211_eapol_key_encrypt(struct ieee80211com *ic,
                memcpy(keybuf, key->iv, EAPOL_KEY_IV_LEN);
                memcpy(keybuf + EAPOL_KEY_IV_LEN, kek, 16);
 
-               rc4_keysetup(&ctx, keybuf, sizeof keybuf);
+               rc4_keysetup(&ctx.rc4, keybuf, sizeof keybuf);
                /* discard the first 256 octets of the ARC4 key stream */
-               rc4_skip(&ctx, RC4STATE);
-               rc4_crypt(&ctx, data, data, len);
+               rc4_skip(&ctx.rc4, RC4STATE);
+               rc4_crypt(&ctx.rc4, data, data, len);
                break;
        case EAPOL_KEY_DESC_V2:
                if (len < 16 || (len & 7) != 0) {
@@ -553,7 +395,8 @@ ieee80211_eapol_key_encrypt(struct ieee80211com *ic,
                        memset(&data[len], 0, n - 1);
                        len += n - 1;
                }
-               aes_key_wrap(kek, 16, data, len / 8, data);
+               aes_key_wrap_set_key_wrap_only(&ctx.aes, kek, 16);
+               aes_key_wrap(&ctx.aes, data, len / 8, data);
                len += 8;       /* AES Key Wrap adds 8 bytes */
                /* update key data length */
                BE_WRITE_2(key->paylen, len);
@@ -572,7 +415,10 @@ int
 ieee80211_eapol_key_decrypt(struct ieee80211_eapol_key *key,
     const u_int8_t *kek)
 {
-       struct rc4_ctx ctx;
+       union {
+               struct rc4_ctx rc4;
+               aes_key_wrap_ctx aes;
+       } ctx;  /* XXX off stack? */
        u_int8_t keybuf[EAPOL_KEY_IV_LEN + 16];
        u_int16_t len, info;
        u_int8_t *data;
@@ -587,17 +433,18 @@ ieee80211_eapol_key_decrypt(struct ieee80211_eapol_key *key,
                memcpy(keybuf, key->iv, EAPOL_KEY_IV_LEN);
                memcpy(keybuf + EAPOL_KEY_IV_LEN, kek, 16);
 
-               rc4_keysetup(&ctx, keybuf, sizeof keybuf);
+               rc4_keysetup(&ctx.rc4, keybuf, sizeof keybuf);
                /* discard the first 256 octets of the ARC4 key stream */
-               rc4_skip(&ctx, RC4STATE);
-               rc4_crypt(&ctx, data, data, len);
+               rc4_skip(&ctx.rc4, RC4STATE);
+               rc4_crypt(&ctx.rc4, data, data, len);
                return 0;
        case EAPOL_KEY_DESC_V2:
                /* Key Data Length must be a multiple of 8 */
                if (len < 16 + 8 || (len & 7) != 0)
                        return 1;
                len -= 8;       /* AES Key Wrap adds 8 bytes */
-               return aes_key_unwrap(kek, 16, data, data, len / 8);
+               aes_key_wrap_set_key(&ctx.aes, kek, 16);
+               return aes_key_unwrap(&ctx.aes, data, data, len / 8);
        }
 
        return 1;       /* unknown Key Descriptor Version */