diff --git a/client.go b/client.go index efae786..29eb01f 100644 --- a/client.go +++ b/client.go @@ -217,17 +217,20 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { if c.TrustCertificate != nil { switch c.TrustCertificate(hostname, cert) { case TrustOnce: - c.KnownHosts.AddTemporary(hostname, cert) + fingerprint := NewFingerprint(cert.Raw, cert.NotAfter) + c.KnownHosts.Add(hostname, fingerprint) return nil case TrustAlways: - c.KnownHosts.Add(hostname, cert) + fingerprint := NewFingerprint(cert.Raw, cert.NotAfter) + c.KnownHosts.Add(hostname, fingerprint) + c.KnownHosts.Write(hostname, fingerprint) return nil } } return errors.New("gemini: certificate not trusted") } - fingerprint := NewFingerprint(cert) + fingerprint := NewFingerprint(cert.Raw, cert.NotAfter) if knownHost.Hex == fingerprint.Hex { return nil } diff --git a/doc.go b/doc.go index ea988ef..e858ee7 100644 --- a/doc.go +++ b/doc.go @@ -1,9 +1,10 @@ /* Package gemini implements the Gemini protocol. -Get makes a Gemini request: +Client is a Gemini client. - resp, err := gemini.Get("gemini://example.com") + client := &gemini.Client{} + resp, err := client.Get("gemini://example.com") if err != nil { // handle error } @@ -13,15 +14,6 @@ Get makes a Gemini request: } // ... -For control over client behavior, create a Client: - - client := &gemini.Client{} - resp, err := client.Get("gemini://example.com") - if err != nil { - // handle error - } - // ... - Server is a Gemini server. server := &gemini.Server{ diff --git a/examples/auth.go b/examples/auth.go index b805439..8976fe8 100644 --- a/examples/auth.go +++ b/examples/auth.go @@ -64,7 +64,7 @@ func main() { } func getSession(cert *x509.Certificate) (*session, bool) { - fingerprint := gemini.NewFingerprint(cert) + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) session, ok := sessions[fingerprint.Hex] return session, ok } @@ -79,7 +79,8 @@ func login(w *gemini.ResponseWriter, r *gemini.Request) { w.WriteHeader(gemini.StatusInput, "Username") return } - fingerprint := gemini.NewFingerprint(r.Certificate.Leaf) + cert := r.Certificate.Leaf + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) sessions[fingerprint.Hex] = &session{ username: username, } @@ -116,7 +117,8 @@ func logout(w *gemini.ResponseWriter, r *gemini.Request) { w.WriteStatus(gemini.StatusCertificateRequired) return } - fingerprint := gemini.NewFingerprint(r.Certificate.Leaf) + cert := r.Certificate.Leaf + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) delete(sessions, fingerprint.Hex) fmt.Fprintln(w, "Successfully logged out.") fmt.Fprintln(w, "=> / Index") diff --git a/examples/client.go b/examples/client.go index 012e6df..176c8f8 100644 --- a/examples/client.go +++ b/examples/client.go @@ -10,9 +10,11 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "time" "git.sr.ht/~adnano/go-gemini" + "git.sr.ht/~adnano/go-xdg" ) const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: @@ -31,9 +33,9 @@ var ( func init() { client.Timeout = 30 * time.Second - client.KnownHosts.LoadDefault() + client.KnownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")) client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust { - fingerprint := gemini.NewFingerprint(cert) + fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) fmt.Printf(trustPrompt, hostname, fingerprint.Hex) scanner.Scan() switch scanner.Text() { diff --git a/gemini.go b/gemini.go index bc3a523..a8285e0 100644 --- a/gemini.go +++ b/gemini.go @@ -2,7 +2,6 @@ package gemini import ( "errors" - "sync" ) var crlf = []byte("\r\n") @@ -13,26 +12,3 @@ var ( ErrInvalidResponse = errors.New("gemini: invalid response") ErrBodyNotAllowed = errors.New("gemini: response body not allowed") ) - -// defaultClient is the default client. It is used by Get and Do. -var defaultClient Client - -// Get performs a Gemini request for the given url. -func Get(url string) (*Response, error) { - setupDefaultClientOnce() - return defaultClient.Get(url) -} - -// Do performs a Gemini request and returns a Gemini response. -func Do(req *Request) (*Response, error) { - setupDefaultClientOnce() - return defaultClient.Do(req) -} - -var defaultClientOnce sync.Once - -func setupDefaultClientOnce() { - defaultClientOnce.Do(func() { - defaultClient.KnownHosts.LoadDefault() - }) -} diff --git a/tofu.go b/tofu.go index 4d87e48..e172f88 100644 --- a/tofu.go +++ b/tofu.go @@ -3,12 +3,12 @@ package gemini import ( "bufio" "crypto/sha512" - "crypto/x509" "fmt" "io" "os" "strconv" "strings" + "time" ) // Trust represents the trustworthiness of a certificate. @@ -27,6 +27,43 @@ type KnownHosts struct { file *os.File } +// Add adds a fingerprint to the list of known hosts. +func (k *KnownHosts) Add(hostname string, fingerprint Fingerprint) { + if k.hosts == nil { + k.hosts = map[string]Fingerprint{} + } + k.hosts[hostname] = fingerprint +} + +// Lookup returns the fingerprint of the certificate corresponding to +// the given hostname. +func (k *KnownHosts) Lookup(hostname string) (Fingerprint, bool) { + c, ok := k.hosts[hostname] + return c, ok +} + +// Write appends a fingerprint to the known hosts file. +func (k *KnownHosts) Write(hostname string, fingerprint Fingerprint) { + if k.file != nil { + k.writeKnownHost(k.file, hostname, fingerprint) + } +} + +// WriteAll writes all of the known hosts to the provided io.Writer. +func (k *KnownHosts) WriteAll(w io.Writer) error { + for h, c := range k.hosts { + if _, err := k.writeKnownHost(w, h, c); err != nil { + return err + } + } + return nil +} + +// writeKnownHost writes the fingerprint to the provided io.Writer. +func (k *KnownHosts) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { + return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires) +} + // Load loads the known hosts from the provided path. // New known hosts will be appended to the file. func (k *KnownHosts) Load(path string) error { @@ -45,31 +82,6 @@ func (k *KnownHosts) Load(path string) error { return nil } -// Add adds a certificate to the list of known hosts. -// If KnownHosts was loaded from a file, Add will append to the file. -func (k *KnownHosts) Add(hostname string, cert *x509.Certificate) { - k.add(hostname, cert, true) -} - -func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) { - if k.hosts == nil { - k.hosts = map[string]Fingerprint{} - } - fingerprint := NewFingerprint(cert) - k.hosts[hostname] = fingerprint - // Append to the file - if write && k.file != nil { - appendKnownHost(k.file, hostname, fingerprint) - } -} - -// Lookup returns the fingerprint of the certificate corresponding to -// the given hostname. -func (k *KnownHosts) Lookup(hostname string) (Fingerprint, bool) { - c, ok := k.hosts[hostname] - return c, ok -} - // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid lines are ignored. func (k *KnownHosts) Parse(r io.Reader) { @@ -104,17 +116,6 @@ func (k *KnownHosts) Parse(r io.Reader) { } } -// Write writes the known hosts to the provided io.Writer. -func (k *KnownHosts) Write(w io.Writer) { - for h, c := range k.hosts { - appendKnownHost(w, h, c) - } -} - -func appendKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { - return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires) -} - // Fingerprint represents a fingerprint using a certain algorithm. type Fingerprint struct { Algorithm string // fingerprint algorithm e.g. SHA-512 @@ -122,9 +123,9 @@ type Fingerprint struct { Expires int64 // unix time of the fingerprint expiration date } -// NewFingerprint returns the SHA-512 fingerprint of the provided certificate. -func NewFingerprint(cert *x509.Certificate) Fingerprint { - sum512 := sha512.Sum512(cert.Raw) +// NewFingerprint returns the SHA-512 fingerprint of the provided raw data. +func NewFingerprint(raw []byte, expires time.Time) Fingerprint { + sum512 := sha512.Sum512(raw) var b strings.Builder for i, f := range sum512 { if i > 0 { @@ -135,6 +136,6 @@ func NewFingerprint(cert *x509.Certificate) Fingerprint { return Fingerprint{ Algorithm: "SHA-512", Hex: b.String(), - Expires: cert.NotAfter.Unix(), + Expires: expires.Unix(), } }