-/* $OpenBSD: wycheproof.go,v 1.141 2023/03/25 09:21:17 tb Exp $ */
+/* $OpenBSD: wycheproof.go,v 1.142 2023/04/06 08:38:53 tb Exp $ */
/*
* Copyright (c) 2018 Joel Sing <jsing@openbsd.org>
* Copyright (c) 2018,2019,2022 Theo Buehler <tb@openbsd.org>
import (
"bytes"
- "crypto/sha1"
- "crypto/sha256"
- "crypto/sha512"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
- "hash"
"io/ioutil"
"log"
"os"
return -1, fmt.Errorf("unknown NID %q", ns)
}
-func hashFromString(hs string) (hash.Hash, error) {
- switch hs {
- case "SHA-1":
- return sha1.New(), nil
- case "SHA-224":
- return sha256.New224(), nil
- case "SHA-256":
- return sha256.New(), nil
- case "SHA-384":
- return sha512.New384(), nil
- case "SHA-512":
- return sha512.New(), nil
- default:
- return nil, fmt.Errorf("unknown hash %q", hs)
- }
-}
-
func hashEvpMdFromString(hs string) (*C.EVP_MD, error) {
switch hs {
case "SHA-1":
}
}
+func hashEvpDigestMessage(md *C.EVP_MD, msg []byte) ([]byte, C.size, error) {
+ size := C.EVP_MD_size(md)
+ if size <= 0 || size > C.EVP_MAX_MD_SIZE {
+ return nil, 0, fmt.Errorf("unexpected MD size %d", size)
+ }
+
+ msgLen := len(msg)
+ if msgLen == 0 {
+ msg = append(msg, 0)
+ }
+
+ digest := make([]byte, size)
+
+ if C.EVP_Digest(unsafe.Pointer(&msg[0]), C.size_t(msgLen), (*C.uchar)(unsafe.Pointer(&digest[0])), nil, md, nil) != 1 {
+ return nil, 0, fmt.Errorf("EVP_Digest failed")
+ }
+
+ return digest, int(size), nil
+}
+
func checkAesCbcPkcs5(ctx *C.EVP_CIPHER_CTX, doEncrypt int, key []byte, keyLen int,
iv []byte, ivLen int, in []byte, inLen int, out []byte, outLen int,
wt *wycheproofTestAesCbcPkcs5) bool {
return cDer, derLen
}
-func runDSATest(dsa *C.DSA, variant testVariant, h hash.Hash, wt *wycheproofTestDSA) bool {
+func runDSATest(dsa *C.DSA, md *C.EVP_MD, variant testVariant, wt *wycheproofTestDSA) bool {
msg, err := hex.DecodeString(wt.Msg)
if err != nil {
log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
}
- h.Reset()
- h.Write(msg)
- msg = h.Sum(nil)
-
- msgLen := len(msg)
- if msgLen == 0 {
- msg = append(msg, 0)
+ msg, msgLen, err := hashEvpDigestMessage(md, msg)
+ if err != nil {
+ log.Fatalf("%v", err)
}
var ret C.int
log.Fatalf("DSA_set0_key returned %d", ret)
}
- h, err := hashFromString(wtg.SHA)
+ md, err := hashEvpMdFromString(wtg.SHA)
if err != nil {
log.Fatalf("Failed to get hash: %v", err)
}
success := true
for _, wt := range wtg.Tests {
- if !runDSATest(dsa, variant, h, wt) {
+ if !runDSATest(dsa, md, variant, wt) {
success = false
}
- if !runDSATest(dsaDER, variant, h, wt) {
+ if !runDSATest(dsaDER, md, variant, wt) {
success = false
}
- if !runDSATest(dsaPEM, variant, h, wt) {
+ if !runDSATest(dsaPEM, md, variant, wt) {
success = false
}
}
return success
}
-func runECDSATest(ecKey *C.EC_KEY, nid int, h hash.Hash, variant testVariant, wt *wycheproofTestECDSA) bool {
+func runECDSATest(ecKey *C.EC_KEY, md *C.EVP_MD, nid int, variant testVariant, wt *wycheproofTestECDSA) bool {
msg, err := hex.DecodeString(wt.Msg)
if err != nil {
log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
}
- h.Reset()
- h.Write(msg)
- msg = h.Sum(nil)
-
- msgLen := len(msg)
- if msgLen == 0 {
- msg = append(msg, 0)
+ msg, msgLen, err := hashEvpDigestMessage(md, msg)
+ if err != nil {
+ log.Fatalf("%v", err)
}
var ret C.int
if err != nil {
log.Fatalf("Failed to get MD NID: %v", err)
}
- h, err := hashFromString(wtg.SHA)
+ md, err := hashEvpMdFromString(wtg.SHA)
if err != nil {
log.Fatalf("Failed to get hash: %v", err)
}
success := true
for _, wt := range wtg.Tests {
- if !runECDSATest(ecKey, nid, h, variant, wt) {
+ if !runECDSATest(ecKey, md, nid, variant, wt) {
success = false
}
}
if err != nil {
log.Fatalf("Failed to get MD NID: %v", err)
}
- h, err := hashFromString(wtg.SHA)
+ md, err := hashEvpMdFromString(wtg.SHA)
if err != nil {
log.Fatalf("Failed to get hash: %v", err)
}
success := true
for _, wt := range wtg.Tests {
- if !runECDSATest(ecKey, nid, h, Webcrypto, wt) {
+ if !runECDSATest(ecKey, md, nid, Webcrypto, wt) {
success = false
}
}
return success
}
-func runRsassaTest(rsa *C.RSA, h hash.Hash, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
+func runRsassaTest(rsa *C.RSA, sha *C.EVP_MD, mgfSha *C.EVP_MD, sLen int, wt *wycheproofTestRsassa) bool {
msg, err := hex.DecodeString(wt.Msg)
if err != nil {
log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
}
- h.Reset()
- h.Write(msg)
- msg = h.Sum(nil)
+ msg, _, err = hashEvpDigestMessage(sha, msg)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
sig, err := hex.DecodeString(wt.Sig)
if err != nil {
log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
}
- msgLen, sigLen := len(msg), len(sig)
- if msgLen == 0 {
- msg = append(msg, 0)
- }
+ sigLen := len(sig)
if sigLen == 0 {
sig = append(sig, 0)
}
rsaN = nil
rsaE = nil
- h, err := hashFromString(wtg.SHA)
- if err != nil {
- log.Fatalf("Failed to get hash: %v", err)
- }
-
sha, err := hashEvpMdFromString(wtg.SHA)
if err != nil {
log.Fatalf("Failed to get hash: %v", err)
success := true
for _, wt := range wtg.Tests {
- if !runRsassaTest(rsa, h, sha, mgfSha, wtg.SLen, wt) {
+ if !runRsassaTest(rsa, sha, mgfSha, wtg.SLen, wt) {
success = false
}
}
return success
}
-func runRSATest(rsa *C.RSA, nid int, h hash.Hash, wt *wycheproofTestRSA) bool {
+func runRSATest(rsa *C.RSA, md *C.EVP_MD, nid int, wt *wycheproofTestRSA) bool {
msg, err := hex.DecodeString(wt.Msg)
if err != nil {
log.Fatalf("Failed to decode message %q: %v", wt.Msg, err)
}
- h.Reset()
- h.Write(msg)
- msg = h.Sum(nil)
+ msg, msgLen, err := hashEvpDigestMessage(md, msg)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
sig, err := hex.DecodeString(wt.Sig)
if err != nil {
log.Fatalf("Failed to decode signature %q: %v", wt.Sig, err)
}
- msgLen, sigLen := len(msg), len(sig)
- if msgLen == 0 {
- msg = append(msg, 0)
- }
+ sigLen := len(sig)
if sigLen == 0 {
sig = append(sig, 0)
}
if err != nil {
log.Fatalf("Failed to get MD NID: %v", err)
}
- h, err := hashFromString(wtg.SHA)
+ md, err := hashEvpMdFromString(wtg.SHA)
if err != nil {
log.Fatalf("Failed to get hash: %v", err)
}
success := true
for _, wt := range wtg.Tests {
- if !runRSATest(rsa, nid, h, wt) {
+ if !runRSATest(rsa, md, nid, wt) {
success = false
}
}