diff --git a/cert.go b/cert.go index d3acbc8..23cdcd4 100644 --- a/cert.go +++ b/cert.go @@ -52,16 +52,9 @@ func (c *CertificateStore) Add(scope string, cert tls.Certificate) error { } // Lookup returns the certificate for the given scope. -func (c *CertificateStore) Lookup(scope string) (*tls.Certificate, error) { +func (c *CertificateStore) Lookup(scope string) (tls.Certificate, bool) { cert, ok := c.store[scope] - if !ok { - return nil, ErrCertificateNotFound - } - // Ensure that the certificate is not expired - if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) { - return &cert, ErrCertificateExpired - } - return &cert, nil + return cert, ok } // Load loads certificates from the given path. diff --git a/client.go b/client.go index 306e758..7dd7b35 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "bufio" "crypto/tls" "crypto/x509" + "errors" "net" "net/url" "path" @@ -164,7 +165,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { } } else if len(via) > 5 { // Default policy of no more than 5 redirects - return resp, ErrTooManyRedirects + return resp, errors.New("gemini: too many redirects") } return c.do(redirect, via) } @@ -182,13 +183,14 @@ func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) { // Search recursively for the certificate scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/") for { - cert, err := c.Certificates.Lookup(scope) - if err == nil { - // Store the certificate - req.Certificate = cert - return cert, err - } - if err == ErrCertificateExpired { + cert, ok := c.Certificates.Lookup(scope) + if ok { + // Ensure that the certificate is not expired + if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) { + // Store the certificate + req.Certificate = &cert + return &cert, nil + } break } scope = path.Dir(scope) @@ -216,21 +218,27 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { return nil } // Check the known hosts - err := c.KnownHosts.Lookup(hostname, cert) - switch err { - case ErrCertificateExpired, ErrCertificateNotFound: - // See if the client trusts the certificate - if c.TrustCertificate != nil { - switch c.TrustCertificate(hostname, cert) { - case TrustOnce: - c.KnownHosts.AddTemporary(hostname, cert) - return nil - case TrustAlways: - c.KnownHosts.Add(hostname, cert) - return nil - } + knownHost, ok := c.KnownHosts.Lookup(hostname) + if ok && time.Now().After(cert.NotAfter) { + // Not expired + fingerprint := NewFingerprint(cert) + if knownHost.Hex != fingerprint.Hex { + return errors.New("gemini: fingerprint does not match") } - return ErrCertificateNotTrusted + return nil } - return err + + // Unknown certificate + // See if the client trusts the certificate + if c.TrustCertificate != nil { + switch c.TrustCertificate(hostname, cert) { + case TrustOnce: + c.KnownHosts.AddTemporary(hostname, cert) + return nil + case TrustAlways: + c.KnownHosts.Add(hostname, cert) + return nil + } + } + return errors.New("gemini: certificate not trusted") } diff --git a/examples/client.go b/examples/client.go index 35266f9..9f5efae 100644 --- a/examples/client.go +++ b/examples/client.go @@ -33,7 +33,8 @@ func init() { client.Timeout = 30 * time.Second client.KnownHosts.LoadDefault() client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust { - fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert)) + fingerprint := gemini.NewFingerprint(cert) + fmt.Printf(trustPrompt, hostname, fingerprint.Hex) scanner.Scan() switch scanner.Text() { case "t": diff --git a/gemini.go b/gemini.go index 3cfd379..bc3a523 100644 --- a/gemini.go +++ b/gemini.go @@ -9,13 +9,9 @@ var crlf = []byte("\r\n") // Errors. var ( - ErrInvalidURL = errors.New("gemini: invalid URL") - ErrInvalidResponse = errors.New("gemini: invalid response") - ErrCertificateExpired = errors.New("gemini: certificate expired") - ErrCertificateNotFound = errors.New("gemini: certificate not found") - ErrCertificateNotTrusted = errors.New("gemini: certificate not trusted") - ErrBodyNotAllowed = errors.New("gemini: response body not allowed") - ErrTooManyRedirects = errors.New("gemini: too many redirects") + ErrInvalidURL = errors.New("gemini: invalid URL") + ErrInvalidResponse = errors.New("gemini: invalid response") + ErrBodyNotAllowed = errors.New("gemini: response body not allowed") ) // defaultClient is the default client. It is used by Get and Do. diff --git a/server.go b/server.go index 8db210e..0590bb9 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package gemini import ( "bufio" "crypto/tls" + "errors" "log" "net" "net/url" @@ -150,12 +151,12 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { if _, ok := s.hosts[hostname]; !ok { - return nil, ErrCertificateNotFound + return nil, errors.New("hostname not registered") } - cert, err := s.Certificates.Lookup(hostname) - switch err { - case ErrCertificateNotFound, ErrCertificateExpired: + // Generate a new certificate if it is missing or expired + cert, ok := s.Certificates.Lookup(hostname) + if !ok || cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) { if s.CreateCertificate != nil { cert, err := s.CreateCertificate(hostname) if err == nil { @@ -165,9 +166,9 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { } return &cert, err } + return nil, errors.New("no certificate") } - - return cert, err + return &cert, nil } // respond responds to a connection. diff --git a/tofu.go b/tofu.go index 05a8624..ac73176 100644 --- a/tofu.go +++ b/tofu.go @@ -8,9 +8,7 @@ import ( "io" "os" "path/filepath" - "strconv" "strings" - "time" ) // Trust represents the trustworthiness of a certificate. @@ -25,7 +23,7 @@ const ( // KnownHosts represents a list of known hosts. // The zero value for KnownHosts is an empty list ready to use. type KnownHosts struct { - hosts map[string]certInfo + hosts map[string]Fingerprint file *os.File } @@ -80,53 +78,34 @@ func (k *KnownHosts) AddTemporary(hostname string, cert *x509.Certificate) { func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) { if k.hosts == nil { - k.hosts = map[string]certInfo{} + k.hosts = map[string]Fingerprint{} } - info := certInfo{ - Algorithm: "SHA-512", - Fingerprint: Fingerprint(cert), - Expires: cert.NotAfter.Unix(), - } - k.hosts[hostname] = info + fingerprint := NewFingerprint(cert) + k.hosts[hostname] = fingerprint // Append to the file if write && k.file != nil { - appendKnownHost(k.file, hostname, info) + appendKnownHost(k.file, hostname, fingerprint) } } -// Lookup looks for the provided certificate in the list of known hosts. -// If the hostname is not in the list, Lookup returns ErrCertificateNotFound. -// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted. -// Otherwise, Lookup returns nil. -func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { - now := time.Now().Unix() - fingerprint := Fingerprint(cert) - if c, ok := k.hosts[hostname]; ok { - if c.Expires <= now { - // Certificate is expired - return ErrCertificateExpired - } - if c.Fingerprint != fingerprint { - // Fingerprint does not match - return ErrCertificateNotTrusted - } - // Certificate is found - return nil - } - return ErrCertificateNotFound +// Lookup returns the fingerprint of the certificate corresponding to +// the given hostname. +func (k *KnownHosts) Lookup(hostname string) (Fingerprint, bool) { + c, ok := k.hosts[hostname] + return c, ok } // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid lines are ignored. func (k *KnownHosts) Parse(r io.Reader) { if k.hosts == nil { - k.hosts = map[string]certInfo{} + k.hosts = map[string]Fingerprint{} } scanner := bufio.NewScanner(r) for scanner.Scan() { text := scanner.Text() parts := strings.Split(text, " ") - if len(parts) < 4 { + if len(parts) < 3 { continue } @@ -136,15 +115,10 @@ func (k *KnownHosts) Parse(r io.Reader) { continue } fingerprint := parts[2] - expires, err := strconv.ParseInt(parts[3], 10, 0) - if err != nil { - continue - } - k.hosts[hostname] = certInfo{ - Algorithm: algorithm, - Fingerprint: fingerprint, - Expires: expires, + k.hosts[hostname] = Fingerprint{ + Algorithm: algorithm, + Hex: fingerprint, } } } @@ -156,18 +130,18 @@ func (k *KnownHosts) Write(w io.Writer) { } } -type certInfo struct { - Algorithm string // fingerprint algorithm e.g. SHA-512 - Fingerprint string // fingerprint in hexadecimal, with ':' between each octet - Expires int64 // unix time of certificate notAfter date +func appendKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { + return fmt.Fprintf(w, "%s %s %s\n", hostname, f.Algorithm, f.Hex) } -func appendKnownHost(w io.Writer, hostname string, c certInfo) (int, error) { - return fmt.Fprintf(w, "%s %s %s %d\n", hostname, c.Algorithm, c.Fingerprint, c.Expires) +// Fingerprint represents a fingerprint using a certain algorithm. +type Fingerprint struct { + Algorithm string // fingerprint algorithm e.g. SHA-512 + Hex string // fingerprint in hexadecimal, with ':' between each octet } -// Fingerprint returns the SHA-512 fingerprint of the provided certificate. -func Fingerprint(cert *x509.Certificate) string { +// NewFingerprint returns the SHA-512 fingerprint of the provided certificate. +func NewFingerprint(cert *x509.Certificate) Fingerprint { sum512 := sha512.Sum512(cert.Raw) var b strings.Builder for i, f := range sum512 { @@ -176,7 +150,10 @@ func Fingerprint(cert *x509.Certificate) string { } fmt.Fprintf(&b, "%02X", f) } - return b.String() + return Fingerprint{ + Algorithm: "SHA-512", + Hex: b.String(), + } } // defaultKnownHostsPath returns the default known_hosts path.