wycheproof: use EVP_MD instead of importing "hash"
authortb <tb@openbsd.org>
Thu, 6 Apr 2023 08:38:53 +0000 (08:38 +0000)
committertb <tb@openbsd.org>
Thu, 6 Apr 2023 08:38:53 +0000 (08:38 +0000)
regress/lib/libcrypto/wycheproof/wycheproof.go

index b3c9225..0698ac9 100644 (file)
@@ -1,4 +1,4 @@
-/* $OpenBSD: wycheproof.go,v 1.141 2023/03/25 09:21:17 tb Exp $ */
+/* $OpenBSD: wycheproof.go,v 1.142 2023/04/06 08:38:53 tb Exp $ */
 /*
  * Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
  * Copyright (c) 2018,2019,2022 Theo Buehler <tb@openbsd.org>
@@ -75,14 +75,10 @@ import "C"
 
 import (
        "bytes"
-       "crypto/sha1"
-       "crypto/sha256"
-       "crypto/sha512"
        "encoding/base64"
        "encoding/hex"
        "encoding/json"
        "fmt"
-       "hash"
        "io/ioutil"
        "log"
        "os"
@@ -564,23 +560,6 @@ func nidFromString(ns string) (int, error) {
        return -1, fmt.Errorf("unknown NID %q", ns)
 }
 
-func hashFromString(hs string) (hash.Hash, error) {
-       switch hs {
-       case "SHA-1":
-               return sha1.New(), nil
-       case "SHA-224":
-               return sha256.New224(), nil
-       case "SHA-256":
-               return sha256.New(), nil
-       case "SHA-384":
-               return sha512.New384(), nil
-       case "SHA-512":
-               return sha512.New(), nil
-       default:
-               return nil, fmt.Errorf("unknown hash %q", hs)
-       }
-}
-
 func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
        switch hs {
        case "SHA-1":
@@ -598,6 +577,26 @@ func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
        }
 }
 
+func hashEvpDigestMessage(md *C.EVP_MD, msg []byte) ([]byte, C.size, error) {
+       size := C.EVP_MD_size(md)
+       if size <= 0 || size > C.EVP_MAX_MD_SIZE {
+               return nil, 0, fmt.Errorf("unexpected MD size %d", size)
+       }
+
+       msgLen := len(msg)
+       if msgLen == 0 {
+               msg = append(msg, 0)
+       }
+
+       digest := make([]byte, size)
+
+       if C.EVP_Digest(unsafe.Pointer(&msg[0]), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&digest[0])), nil, md, nil) != 1 {
+               return nil, 0, fmt.Errorf("EVP_Digest failed")
+       }
+
+       return digest, int(size), nil
+}
+
 func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
        iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
        wt *wycheproofTestAesCbcPkcs5) bool {
@@ -1337,19 +1336,15 @@ func encodeDSAP1363Sig(wtSig string) (*C.uchar, C.int) {
        return cDer, derLen
 }
 
-func runDSATest(dsa *C.DSA, variant testVariant, h hash.Hash, wt *wycheproofTestDSA) bool {
+func runDSATest(dsa *C.DSA, md *C.EVP_MD, variant testVariant, wt *wycheproofTestDSA) bool {
        msg, err := hex.DecodeString(wt.Msg)
        if err != nil {
                log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
        }
 
-       h.Reset()
-       h.Write(msg)
-       msg = h.Sum(nil)
-
-       msgLen := len(msg)
-       if msgLen == 0 {
-               msg = append(msg, 0)
+       msg, msgLen, err := hashEvpDigestMessage(md, msg)
+       if err != nil {
+               log.Fatalf("%v", err)
        }
 
        var ret C.int
@@ -1433,7 +1428,7 @@ func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestG
                log.Fatalf("DSA_set0_key returned %d", ret)
        }
 
-       h, err := hashFromString(wtg.SHA)
+       md, err := hashEvpMdFromString(wtg.SHA)
        if err != nil {
                log.Fatalf("Failed to get hash: %v", err)
        }
@@ -1475,13 +1470,13 @@ func runDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTestG
 
        success := true
        for _, wt := range wtg.Tests {
-               if !runDSATest(dsa, variant, h, wt) {
+               if !runDSATest(dsa, md, variant, wt) {
                        success = false
                }
-               if !runDSATest(dsaDER, variant, h, wt) {
+               if !runDSATest(dsaDER, md, variant, wt) {
                        success = false
                }
-               if !runDSATest(dsaPEM, variant, h, wt) {
+               if !runDSATest(dsaPEM, md, variant, wt) {
                        success = false
                }
        }
@@ -1722,19 +1717,15 @@ func runECDHWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDHWeb
        return success
 }
 
-func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, variant testVariant, wt *wycheproofTestECDSA) bool {
+func runECDSATest(ecKey *C.EC_KEY, md *C.EVP_MD, nid int, variant testVariant, wt *wycheproofTestECDSA) bool {
        msg, err := hex.DecodeString(wt.Msg)
        if err != nil {
                log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
        }
 
-       h.Reset()
-       h.Write(msg)
-       msg = h.Sum(nil)
-
-       msgLen := len(msg)
-       if msgLen == 0 {
-               msg = append(msg, 0)
+       msg, msgLen, err := hashEvpDigestMessage(md, msg)
+       if err != nil {
+               log.Fatalf("%v", err)
        }
 
        var ret C.int
@@ -1810,14 +1801,14 @@ func runECDSATestGroup(algorithm string, variant testVariant, wtg *wycheproofTes
        if err != nil {
                log.Fatalf("Failed to get MD NID: %v", err)
        }
-       h, err := hashFromString(wtg.SHA)
+       md, err := hashEvpMdFromString(wtg.SHA)
        if err != nil {
                log.Fatalf("Failed to get hash: %v", err)
        }
 
        success := true
        for _, wt := range wtg.Tests {
-               if !runECDSATest(ecKey, nid, h, variant, wt) {
+               if !runECDSATest(ecKey, md, nid, variant, wt) {
                        success = false
                }
        }
@@ -1914,14 +1905,14 @@ func runECDSAWebCryptoTestGroup(algorithm string, wtg *wycheproofTestGroupECDSAW
        if err != nil {
                log.Fatalf("Failed to get MD NID: %v", err)
        }
-       h, err := hashFromString(wtg.SHA)
+       md, err := hashEvpMdFromString(wtg.SHA)
        if err != nil {
                log.Fatalf("Failed to get hash: %v", err)
        }
 
        success := true
        for _, wt := range wtg.Tests {
-               if !runECDSATest(ecKey, nid, h, Webcrypto, wt) {
+               if !runECDSATest(ecKey, md, nid, Webcrypto, wt) {
                        success = false
                }
        }
@@ -2512,25 +2503,23 @@ func runRsaesPkcs1TestGroup(algorithm string, wtg *wycheproofTestGroupRsaesPkcs1
        return success
 }
 
-func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
+func runRsassaTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
        msg, err := hex.DecodeString(wt.Msg)
        if err != nil {
                log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
        }
 
-       h.Reset()
-       h.Write(msg)
-       msg = h.Sum(nil)
+       msg, _, err = hashEvpDigestMessage(sha, msg)
+       if err != nil {
+               log.Fatalf("%v", err)
+       }
 
        sig, err := hex.DecodeString(wt.Sig)
        if err != nil {
                log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
        }
 
-       msgLen, sigLen := len(msg), len(sig)
-       if msgLen == 0 {
-               msg = append(msg, 0)
-       }
+       sigLen := len(sig)
        if sigLen == 0 {
                sig = append(sig, 0)
        }
@@ -2599,11 +2588,6 @@ func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
        rsaN = nil
        rsaE = nil
 
-       h, err := hashFromString(wtg.SHA)
-       if err != nil {
-               log.Fatalf("Failed to get hash: %v", err)
-       }
-
        sha, err := hashEvpMdFromString(wtg.SHA)
        if err != nil {
                log.Fatalf("Failed to get hash: %v", err)
@@ -2616,32 +2600,30 @@ func runRsassaTestGroup(algorithm string, wtg *wycheproofTestGroupRsassa) bool {
 
        success := true
        for _, wt := range wtg.Tests {
-               if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
+               if !runRsassaTest(rsa, sha, mgfSha, wtg.SLen, wt) {
                        success = false
                }
        }
        return success
 }
 
-func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
+func runRSATest(rsa *C.RSA, md *C.EVP_MD, nid int, wt *wycheproofTestRSA) bool {
        msg, err := hex.DecodeString(wt.Msg)
        if err != nil {
                log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
        }
 
-       h.Reset()
-       h.Write(msg)
-       msg = h.Sum(nil)
+       msg, msgLen, err := hashEvpDigestMessage(md, msg)
+       if err != nil {
+               log.Fatalf("%v", err)
+       }
 
        sig, err := hex.DecodeString(wt.Sig)
        if err != nil {
                log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
        }
 
-       msgLen, sigLen := len(msg), len(sig)
-       if msgLen == 0 {
-               msg = append(msg, 0)
-       }
+       sigLen := len(sig)
        if sigLen == 0 {
                sig = append(sig, 0)
        }
@@ -2695,14 +2677,14 @@ func runRSATestGroup(algorithm string, wtg *wycheproofTestGroupRSA) bool {
        if err != nil {
                log.Fatalf("Failed to get MD NID: %v", err)
        }
-       h, err := hashFromString(wtg.SHA)
+       md, err := hashEvpMdFromString(wtg.SHA)
        if err != nil {
                log.Fatalf("Failed to get hash: %v", err)
        }
 
        success := true
        for _, wt := range wtg.Tests {
-               if !runRSATest(rsa, nid, h, wt) {
+               if !runRSATest(rsa, md, nid, wt) {
                        success = false
                }
        }