Add test coverage for peer certificate info and connection info.
authorjsing <jsing@openbsd.org>
Tue, 13 Oct 2015 13:58:33 +0000 (13:58 +0000)
committerjsing <jsing@openbsd.org>
Tue, 13 Oct 2015 13:58:33 +0000 (13:58 +0000)
regress/lib/libtls/gotls/tls.go
regress/lib/libtls/gotls/tls_test.go

index 6dc51b8..74c34b4 100644 (file)
@@ -15,6 +15,7 @@ import "C"
 import (
        "errors"
        "fmt"
+       "time"
        "unsafe"
 )
 
@@ -115,6 +116,84 @@ func (t *TLS) Error() string {
        return ""
 }
 
+// PeerCertProvided returns whether the peer provided a certificate.
+func (t *TLS) PeerCertProvided() bool {
+       return C.tls_peer_cert_provided(t.ctx) == 1
+}
+
+// PeerCertContainsName checks whether the peer certificate contains
+// the specified name.
+func (t *TLS) PeerCertContainsName(name string) bool {
+       n := C.CString(name)
+       defer C.free(unsafe.Pointer(n))
+       return C.tls_peer_cert_contains_name(t.ctx, n) == 1
+}
+
+// PeerCertIssuer returns the issuer of the peer certificate.
+func (t *TLS) PeerCertIssuer() (string, error) {
+       issuer := C.tls_peer_cert_issuer(t.ctx)
+       if issuer == nil {
+               return "", errors.New("no issuer returned")
+       }
+       return C.GoString(issuer), nil
+}
+
+// PeerCertSubject returns the subject of the peer certificate.
+func (t *TLS) PeerCertSubject() (string, error) {
+       subject := C.tls_peer_cert_subject(t.ctx)
+       if subject == nil {
+               return "", errors.New("no subject returned")
+       }
+       return C.GoString(subject), nil
+}
+
+// PeerCertHash returns a hash of the peer certificate.
+func (t *TLS) PeerCertHash() (string, error) {
+       hash := C.tls_peer_cert_hash(t.ctx)
+       if hash == nil {
+               return "", errors.New("no hash returned")
+       }
+       return C.GoString(hash), nil
+}
+
+// PeerCertNotBefore returns the notBefore time from the peer
+// certificate.
+func (t *TLS) PeerCertNotBefore() (time.Time, error) {
+       notBefore := C.tls_peer_cert_notbefore(t.ctx)
+       if notBefore == -1 {
+               return time.Time{}, errors.New("no notBefore time returned")
+       }
+       return time.Unix(int64(notBefore), 0), nil
+}
+
+// PeerCertNotAfter returns the notAfter time from the peer
+// certificate.
+func (t *TLS) PeerCertNotAfter() (time.Time, error) {
+       notAfter := C.tls_peer_cert_notafter(t.ctx)
+       if notAfter == -1 {
+               return time.Time{}, errors.New("no notAfter time")
+       }
+       return time.Unix(int64(notAfter), 0), nil
+}
+
+// ConnVersion returns the protocol version of the connection.
+func (t *TLS) ConnVersion() (string, error) {
+       ver := C.tls_conn_version(t.ctx)
+       if ver == nil {
+               return "", errors.New("no connection version")
+       }
+       return C.GoString(ver), nil
+}
+
+// ConnCipher returns the cipher suite used for the connection.
+func (t *TLS) ConnCipher() (string, error) {
+       cipher := C.tls_conn_cipher(t.ctx)
+       if cipher == nil {
+               return "", errors.New("no connection cipher")
+       }
+       return C.GoString(cipher), nil
+}
+
 // Connect attempts to establish an TLS connection to the specified host on
 // the given port. The host may optionally contain a colon separated port
 // value if the port string is specified as an empty string.
index 2afcf93..2331ec0 100644 (file)
@@ -10,6 +10,18 @@ import (
        "os"
        "strings"
        "testing"
+       "time"
+)
+
+const (
+       httpContent = "Hello, TLS!"
+
+       certHash = "SHA256:448f628a8a65aa18560e53a80c53acb38c51b427df0334082349141147dc9bf6"
+)
+
+var (
+       certNotBefore = time.Unix(0, 0)
+       certNotAfter = certNotBefore.Add(1000000 * time.Hour)
 )
 
 // createCAFile writes a PEM encoded version of the certificate out to a
@@ -30,9 +42,7 @@ func createCAFile(cert []byte) (string, error) {
        return f.Name(), nil
 }
 
-const httpContent = "Hello, TLS!"
-
-func TestTLSBasic(t *testing.T) {
+func newTestServer() (*httptest.Server, *url.URL, string, error) {
        ts := httptest.NewTLSServer(
                http.HandlerFunc(
                        func(w http.ResponseWriter, r *http.Request) {
@@ -40,18 +50,27 @@ func TestTLSBasic(t *testing.T) {
                        },
                ),
        )
-       defer ts.Close()
 
        u, err := url.Parse(ts.URL)
        if err != nil {
-               t.Fatalf("Failed to parse URL %q: %v", ts.URL, err)
+               return nil, nil, "", fmt.Errorf("failed to parse URL %q: %v", ts.URL, err)
        }
 
        caFile, err := createCAFile(ts.TLS.Certificates[0].Certificate[0])
        if err != nil {
-               t.Fatalf("Failed to create CA file: %v", err)
+               return nil, nil, "", fmt.Errorf("failed to create CA file: %v", err)
+       }
+
+       return ts, u, caFile, nil
+}
+
+func TestTLSBasic(t *testing.T) {
+       ts, u, caFile, err := newTestServer()
+       if err != nil {
+               t.Fatalf("Failed to start test server: %v", err)
        }
        defer os.Remove(caFile)
+       defer ts.Close()
 
        if err := Init(); err != nil {
                t.Fatal(err)
@@ -98,3 +117,130 @@ func TestTLSBasic(t *testing.T) {
                t.Errorf("Response does not contain %q", httpContent)
        }
 }
+
+func TestTLSInfo(t *testing.T) {
+       ts, u, caFile, err := newTestServer()
+       if err != nil {
+               t.Fatalf("Failed to start test server: %v", err)
+       }
+       defer os.Remove(caFile)
+       defer ts.Close()
+
+       if err := Init(); err != nil {
+               t.Fatal(err)
+       }
+
+       cfg, err := NewConfig()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer cfg.Free()
+       cfg.SetCAFile(caFile)
+
+       tls, err := NewClient(cfg)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer tls.Free()
+
+       t.Logf("Connecting to %s", u.Host)
+
+       if err := tls.Connect(u.Host, ""); err != nil {
+               t.Fatal(err)
+       }
+       defer func() {
+               if err := tls.Close(); err != nil {
+                       t.Fatalf("Close failed: %v", err)
+               }
+       }()
+
+       // All of these should fail since the handshake has not completed.
+       if _, err := tls.ConnVersion(); err == nil {
+               t.Error("ConnVersion() return nil error, want error")
+       }
+       if _, err := tls.ConnCipher(); err == nil {
+               t.Error("ConnCipher() return nil error, want error")
+       }
+
+       if got, want := tls.PeerCertProvided(), false; got != want {
+               t.Errorf("PeerCertProvided() = %v, want %v", got, want)
+       }
+       for _, name := range []string{"127.0.0.1", "::1", "example.com"} {
+               if got, want := tls.PeerCertContainsName(name), false; got != want {
+                       t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
+               }
+       }
+
+       if _, err := tls.PeerCertIssuer(); err == nil {
+               t.Error("PeerCertIssuer() returned nil error, want error")
+       }
+       if _, err := tls.PeerCertSubject(); err == nil {
+               t.Error("PeerCertSubject() returned nil error, want error")
+       }
+       if _, err := tls.PeerCertHash(); err == nil {
+               t.Error("PeerCertHash() returned nil error, want error")
+       }
+       if _, err := tls.PeerCertNotBefore(); err == nil {
+               t.Error("PeerCertNotBefore() returned nil error, want error")
+       }
+       if _, err := tls.PeerCertNotAfter(); err == nil {
+               t.Error("PeerCertNotAfter() returned nil error, want error")
+       }
+
+       // Complete the handshake...
+       if err := tls.Handshake(); err != nil {
+               t.Fatalf("Handshake failed: %v", err)
+       }
+
+       if version, err := tls.ConnVersion(); err != nil {
+               t.Errorf("ConnVersion() return error: %v", err)
+       } else {
+               t.Logf("Protocol version: %v", version)
+       }
+       if cipher, err := tls.ConnCipher(); err != nil {
+               t.Errorf("ConnCipher() return error: %v", err)
+       } else {
+               t.Logf("Cipher: %v", cipher)
+       }
+
+       if got, want := tls.PeerCertProvided(), true; got != want {
+               t.Errorf("PeerCertProvided() = %v, want %v", got, want)
+       }
+       for _, name := range []string{"127.0.0.1", "::1", "example.com"} {
+               if got, want := tls.PeerCertContainsName(name), true; got != want {
+                       t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
+               }
+       }
+
+       if issuer, err := tls.PeerCertIssuer(); err != nil {
+               t.Errorf("PeerCertIssuer() returned error: %v", err)
+       } else {
+               t.Logf("Issuer: %v", issuer)
+       }
+       if subject, err := tls.PeerCertSubject(); err != nil {
+               t.Errorf("PeerCertSubject() returned error: %v", err)
+       } else {
+               t.Logf("Subject: %v", subject)
+       }
+       if hash, err := tls.PeerCertHash(); err != nil {
+               t.Errorf("PeerCertHash() returned error: %v", err)
+       } else if hash != certHash {
+               t.Errorf("Got cert hash %q, want %q", hash, certHash)
+       } else {
+               t.Logf("Hash: %v", hash)
+       }
+       if notBefore, err := tls.PeerCertNotBefore(); err != nil {
+               t.Errorf("PeerCertNotBefore() returned error: %v", err)
+       } else if !certNotBefore.Equal(notBefore) {
+               t.Errorf("Got cert notBefore %v, want %v", notBefore.UTC(), certNotBefore.UTC())
+       } else {
+               t.Logf("NotBefore: %v", notBefore.UTC())
+       }
+       if notAfter, err := tls.PeerCertNotAfter(); err != nil {
+               t.Errorf("PeerCertNotAfter() returned error: %v", err)
+       } else if !certNotAfter.Equal(notAfter) {
+               t.Errorf("Got cert notAfter %v, want %v", notAfter.UTC(), certNotAfter.UTC())
+       } else {
+               t.Logf("NotAfter: %v", notAfter.UTC())
+       }
+}