1
1
Fork 0
mirror of https://git.sr.ht/~emersion/tlstunnel synced 2024-05-25 11:06:07 +02:00

Add `tls ca` directive

This commit is contained in:
Simon Ser 2020-09-09 14:08:20 +02:00
parent 6ac58fe450
commit 137be93297
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
2 changed files with 47 additions and 21 deletions

40
main.go
View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"net" "net"
@ -17,9 +16,18 @@ func main() {
srv := NewServer() srv := NewServer()
for _, d := range cfg.ChildrenByName("frontend") { for _, d := range cfg.Children {
if err := parseFrontend(srv, d); err != nil { var err error
log.Fatalf("failed to parse frontend: %v", err) switch d.Name {
case "frontend":
err = parseFrontend(srv, d)
case "tls":
err = parseTLS(srv, d)
default:
log.Fatalf("unknown %q directive", d.Name)
}
if err != nil {
log.Fatalf("directive %q: %v", d.Name, err)
} }
} }
@ -43,7 +51,6 @@ func parseFrontend(srv *Server, d *Directive) error {
return err return err
} }
var certNames []string
for _, listenAddr := range d.Params { for _, listenAddr := range d.Params {
host, port, err := net.SplitHostPort(listenAddr) host, port, err := net.SplitHostPort(listenAddr)
if err != nil { if err != nil {
@ -54,8 +61,9 @@ func parseFrontend(srv *Server, d *Directive) error {
var name string var name string
if host != "" && host != "localhost" && net.ParseIP(host) == nil { if host != "" && host != "localhost" && net.ParseIP(host) == nil {
name = host name = host
certNames = append(certNames, host)
host = "" host = ""
srv.ManagedNames = append(srv.ManagedNames, name)
} }
addr := net.JoinHostPort(host, port) addr := net.JoinHostPort(host, port)
@ -66,10 +74,6 @@ func parseFrontend(srv *Server, d *Directive) error {
} }
} }
if err := srv.certmagic.ManageAsync(context.Background(), certNames); err != nil {
return fmt.Errorf("failed to manage TLS certificates: %v", err)
}
return nil return nil
} }
@ -103,3 +107,19 @@ func parseBackend(backend *Backend, d *Directive) error {
return nil return nil
} }
func parseTLS(srv *Server, d *Directive) error {
for _, child := range d.Children {
switch child.Name {
case "ca":
var caURL string
if err := child.ParseParams(&caURL); err != nil {
return err
}
srv.acmeManager.CA = caURL
default:
return fmt.Errorf("unknown %q directive", child.Name)
}
}
return nil
}

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -11,27 +12,28 @@ import (
) )
type Server struct { type Server struct {
Listeners map[string]*Listener // indexed by listening address Listeners map[string]*Listener // indexed by listening address
Frontends []*Frontend Frontends []*Frontend
certmagic *certmagic.Config ManagedNames []string
acmeManager *certmagic.ACMEManager
certmagic *certmagic.Config
} }
func NewServer() *Server { func NewServer() *Server {
cfg := certmagic.NewDefault() cfg := certmagic.NewDefault()
acme := certmagic.DefaultACME mgr := certmagic.NewACMEManager(cfg, certmagic.DefaultACME)
// TODO: use production CA mgr.Agreed = true
acme.CA = certmagic.LetsEncryptStagingCA
acme.Agreed = true
// TODO: enable HTTP challenge by peeking incoming requests on port 80 // TODO: enable HTTP challenge by peeking incoming requests on port 80
acme.DisableHTTPChallenge = true mgr.DisableHTTPChallenge = true
mgr := certmagic.NewACMEManager(cfg, acme)
cfg.Issuer = mgr cfg.Issuer = mgr
cfg.Revoker = mgr cfg.Revoker = mgr
return &Server{ return &Server{
Listeners: make(map[string]*Listener), Listeners: make(map[string]*Listener),
certmagic: cfg, acmeManager: mgr,
certmagic: cfg,
} }
} }
@ -46,6 +48,10 @@ func (srv *Server) RegisterListener(addr string) *Listener {
} }
func (srv *Server) Start() error { func (srv *Server) Start() error {
if err := srv.certmagic.ManageAsync(context.Background(), srv.ManagedNames); err != nil {
return fmt.Errorf("failed to manage TLS certificates: %v", err)
}
for _, ln := range srv.Listeners { for _, ln := range srv.Listeners {
if err := ln.Start(); err != nil { if err := ln.Start(); err != nil {
return err return err