Skip to content

Commit

Permalink
feat: support re-stapling certificates
Browse files Browse the repository at this point in the history
Previously, certificates would be stapled when they got loaded but would
not be re-stapled automatically, unless the certificate was reloaded.
Now the certificate is automatically re-stapled well before the OCSP
response should expire.

Signed-off-by: Matthew Penner <[email protected]>
  • Loading branch information
matthewpi committed Oct 27, 2024
1 parent 6d819f3 commit b669a7f
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 71 deletions.
204 changes: 175 additions & 29 deletions certwatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"errors"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"

"github.com/fsnotify/fsnotify"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
goocsp "golang.org/x/crypto/ocsp"

"github.com/matthewpi/certwatcher/internal/ocsp"
"github.com/matthewpi/certwatcher/internal/sets"
Expand Down Expand Up @@ -49,7 +51,22 @@ type Watcher struct {
certPath string
keyPath string

// certμ is used to guard writes to [cert].
//
// We use both a [sync.Mutex] and [atomic.Pointer] for [cert] for two
// different reasons.
//
// The [atomic.Pointer] protects against race-conditions and is perfect for
// protecting data that is written infrequently but constantly being read.
//
// The [sync.Mutex] is used to guard multiple simultaneous reconfigurations
// of [cert]. There are multiple ways that [cert] could be modified, and we
// don't want to have multiple updates running simultaneously, not because
// they will cause a race-condition, but because it can lead to
// unpredictable outcomes.
certμ sync.Mutex
cert atomic.Pointer[tls.Certificate]
ocsp *goocsp.Response
debounce debounced

fsWatcher *fsnotify.Watcher
Expand All @@ -59,6 +76,8 @@ type Watcher struct {
meter metric.Meter
reconfigureTotalCounter metric.Int64Counter
reconfigureErrorCounter metric.Int64Counter
stapleTotalCounter metric.Int64Counter
stapleErrorCounter metric.Int64Counter
}

// New creates a new certwatcher [Watcher], capable of reloading certificates on
Expand Down Expand Up @@ -90,20 +109,33 @@ func New(options Options) (*Watcher, error) {
if err != nil {
return nil, fmt.Errorf("certwatcher: failed to create otel meter: %w", err)
}
w.stapleTotalCounter, err = w.meter.Int64Counter("certwatcher.ocsp.staple.total")
if err != nil {
return nil, fmt.Errorf("certwatcher: failed to create otel meter: %w", err)
}
w.stapleErrorCounter, err = w.meter.Int64Counter("certwatcher.ocsp.staple.errors")
if err != nil {
return nil, fmt.Errorf("certwatcher: failed to create otel meter: %w", err)
}
w.fsWatcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("certwatcher: failed to create fswatcher: %w", err)
}
return w, nil
}

// GetCertificate satisfies tls.Config#GetCertificate. This function should be
// Certificate reports the [*tls.Certificate] currently held by the watcher.
//
// The value is read atomically from w.cert. The result is nil when nothing
// has been stored yet.
func (w *Watcher) Certificate() *tls.Certificate {
	current := w.cert.Load()
	return current
}

// GetCertificate has the signature required by [tls.Config.GetCertificate],
// so it can be assigned to that field to serve whatever certificate
// certwatcher currently holds.
//
// The client hello is ignored; the same certificate is returned for every
// handshake. The returned error is always nil.
func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	current := w.cert.Load()
	return current, nil
}

// GetClientCertificate satisfies tls.Config#GetClientCertificate. This function
// GetClientCertificate satisfies [tls.Config.GetClientCertificate]. This function
// should be used on a tls.Config to use the certificate loaded by certwatcher.
func (w *Watcher) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return w.cert.Load(), nil
Expand All @@ -115,10 +147,17 @@ func (w *Watcher) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Cert
// This method is both used for initial configuration and for reconfiguration
// if the certificate paths need to be changed (e.g. config hot-reloading).
func (w *Watcher) Reconfigure(ctx context.Context, certPath, keyPath string) error {
// Load the certificate from disk.
// Lock the certificate so nothing else tries to refresh or reconfigure it.
w.certμ.Lock()
defer w.certμ.Unlock()

// Increment the reconfigure counter.
w.reconfigureTotalCounter.Add(ctx, 1)

// Load the certificate from disk and parse its leaf.
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
// Increment the error counter.
w.reconfigureErrorCounter.Add(ctx, 1)
return fmt.Errorf("certwatcher: failed to load x509 certificate: %w", err)
}
Expand Down Expand Up @@ -148,16 +187,17 @@ func (w *Watcher) Reconfigure(ctx context.Context, certPath, keyPath string) err
slog.String("not_after", notAfter.Format(time.DateTime)),
)

// Attempt to staple the certificate only if we were able to load its leaf.
cert, err = w.staple(ctx, cert)
if err != nil {
certPtr := &cert

// Staple the certificate.
if err := w.staple(ctx, certPtr); err != nil {
w.logger.LogAttrs(ctx, slog.LevelWarn, "failed to staple certificate", slog.Any("err", err))
}

// Update the certificate we are serving.
// We use swap, so we can display a helpful message if this is the first
// time a certificate was loaded or if it was a reload.
if wCert := w.cert.Swap(&cert); wCert == nil {
if wCert := w.cert.Swap(certPtr); wCert == nil {
w.logger.LogAttrs(ctx, slog.LevelInfo, "certificate loaded", fields...)
} else {
w.logger.LogAttrs(ctx, slog.LevelInfo, "certificate reloaded", fields...)
Expand All @@ -170,44 +210,63 @@ func (w *Watcher) Reconfigure(ctx context.Context, certPath, keyPath string) err
}

// staple attempts to perform OCSP stapling on the supplied certificate.
func (w *Watcher) staple(ctx context.Context, cert tls.Certificate) (tls.Certificate, error) {
// If stapling is disabled, return the supplied certificate.
func (w *Watcher) staple(ctx context.Context, cert *tls.Certificate) error {
// If stapling is disabled, don't do anything.
if w.options.DontStaple {
return cert, nil
}

// If the certificate is already stapled, return it.
if cert.OCSPStaple != nil {
w.logger.DebugContext(ctx, "certificate was already stapled")
return cert, nil
return nil
}

// Attempt to staple the certificate.
s := ocsp.Stapler{Certificate: cert}

w.logger.LogAttrs(
ctx,
slog.LevelInfo,
"attempting to staple certificate...",
slog.Any("ocsp_servers", cert.Leaf.OCSPServer),
slog.Any("issuing_certificate_url", cert.Leaf.IssuingCertificateURL),
)
if err := s.Staple(ctx); err != nil {

// Staple the certificate.
res, err := ocsp.Staple(ctx, cert)
if err != nil {
if errors.Is(err, ocsp.ErrNoOCSPServer) {
w.logger.InfoContext(ctx, "certificate has no ocsp servers, cannot staple")
return cert, nil
w.logger.LogAttrs(ctx, slog.LevelInfo, "certificate has no ocsp servers, unable to staple")
return nil
}
return cert, fmt.Errorf("certwatcher: error stapling certificate: %w", err)
return fmt.Errorf("certwatcher: error stapling certificate: %w", err)
}

// Update the certificate with the stapled one.
cert = s.Certificate
// Store the new OCSP response.
if res == nil {
w.ocsp = nil
return nil
}
w.ocsp = res.Response

var status string
switch res.Status {
case goocsp.Good:
status = "Good"
case goocsp.Revoked:
status = "Revoked"
case goocsp.Unknown:
status = "Unknown"
case goocsp.ServerFailed:
status = "ServerFailed"
}

// OCSPStaple will only be set if `status` is Good.
if cert.OCSPStaple == nil {
w.logger.WarnContext(ctx, "certificate was not stapled")
w.logger.LogAttrs(ctx, slog.LevelWarn, "certificate was not stapled", slog.Group("ocsp", slog.String("status", status)))
} else {
w.logger.InfoContext(ctx, "stapled certificate")
w.logger.LogAttrs(ctx, slog.LevelInfo, "certificate stapled", slog.Group("ocsp",
slog.String("status", status), slog.Time("produced_at", res.ProducedAt),
slog.Time("this_update", res.ThisUpdate), slog.Time("next_update", res.NextUpdate),
slog.String("hash_algorithm", res.IssuerHash.String()),
slog.String("signature_algorithm", res.SignatureAlgorithm.String()),
))
}
return cert, nil

return nil
}

// configureFsWatcher configures the fswatcher to watch the given paths. If any
Expand Down Expand Up @@ -278,7 +337,7 @@ func (w *Watcher) forPaths(ctx context.Context, fn func(string) error, paths ...
if err := fn(f); err != nil {
watchErr = err
// We want to keep trying, so don't return the error.
return false, nil //nolint:nilerr
return false, nil
}
// We've successfully done what we needed to do with the path,
// remove it from the set.
Expand All @@ -297,7 +356,7 @@ func (w *Watcher) forPaths(ctx context.Context, fn func(string) error, paths ...
// certificate when necessary.
func (w *Watcher) Start(ctx context.Context) {
if w.fsWatcher == nil {
panic("certwatcher: fsWatcher is nil")
w.logger.LogAttrs(ctx, slog.LevelError, "filesystem watcher is not configured, unable to start certwatcher")
return
}

Expand All @@ -312,6 +371,12 @@ func (w *Watcher) Start(ctx context.Context) {
// watch watches for incoming events from fsnotify and passes them off to
// handleEvent.
func (w *Watcher) watch(ctx context.Context) {
// If stapling is enabled, start a go-routine that will re-staple the
// certificate.
if !w.options.DontStaple {
go w.waitForOCSPRefresh(ctx)
}

for {
select {
case <-ctx.Done():
Expand All @@ -330,6 +395,87 @@ func (w *Watcher) watch(ctx context.Context) {
}
}

// waitForOCSPRefresh keeps the served certificate's OCSP staple up-to-date.
//
// It arms a timer for the next refresh time derived from the currently stored
// OCSP response and, every time the timer fires, calls refreshOCSP and re-arms
// the timer with the duration that call returns. It blocks until ctx is
// cancelled, so it is meant to be run in its own goroutine.
func (w *Watcher) waitForOCSPRefresh(ctx context.Context) {
	// w.ocsp is written by staple, which only runs while certμ is held
	// (from Reconfigure and refreshOCSP). Take the lock for the initial read
	// so it cannot race with a concurrent reconfiguration.
	w.certμ.Lock()
	initial := time.Until(ocsp.RefreshTime(w.ocsp))
	w.certμ.Unlock()

	t := time.NewTimer(initial)
	// The timer is local to this function and nothing reads t.C after we
	// return, so stopping it is sufficient; draining the channel here would
	// only risk blocking.
	defer t.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-t.C:
			// The channel has just been drained, so Reset is safe. Refresh
			// the OCSP staple and wait for whatever duration it asks for.
			t.Reset(w.refreshOCSP(ctx))
		}
	}
}

// refreshOCSP refreshes the OCSP staple of the actively loaded certificate and
// returns how long the caller should wait before calling it again.
//
// certμ is held for the whole call so it cannot overlap with Reconfigure or
// another refresh. The currently served certificate is never modified in
// place: a copy is stapled and then atomically swapped in.
//
// The returned duration is always positive. When there is nothing to staple
// yet, when stapling fails, or when the next refresh time is not in the
// future, a fixed retry interval is returned so the caller's timer can never
// spin in a hot loop.
func (w *Watcher) refreshOCSP(ctx context.Context) time.Duration {
	// retryInterval is how long to wait before trying again when a refresh
	// could not be completed or no usable refresh time is available.
	//
	// We could also use a backoff if we wanted better control, but OCSP
	// staples usually last multiple hours, so retrying after a little while
	// (even multiple times) should be more than sufficient.
	const retryInterval = 30 * time.Second

	// Lock the certificate so nothing else tries to refresh or reconfigure it.
	w.certμ.Lock()
	defer w.certμ.Unlock()

	// If the refresh time is still in the future there is nothing to do yet
	// (e.g. a Reconfigure re-stapled the certificate since the timer was
	// armed); just wait out the remainder.
	// NOTE(review): ocsp.RefreshTime is defined elsewhere; this assumes it
	// returns the instant at which the staple should be renewed — confirm its
	// behavior for a nil response.
	if wait := time.Until(ocsp.RefreshTime(w.ocsp)); wait > 0 {
		return wait
	}

	// Nothing can be stapled until a certificate with a parsed leaf has been
	// loaded; staple dereferences cert.Leaf.
	current := w.cert.Load()
	if current == nil || current.Leaf == nil {
		w.logger.LogAttrs(ctx, slog.LevelDebug, "no certificate loaded, skipping ocsp refresh")
		return retryInterval
	}

	// Clone the currently loaded certificate.
	//
	// We need to clone the certificate to avoid a race condition with
	// OCSPStaple. That's the entire point of using an [atomic.Pointer] for
	// cert, it allows us to return a TLS certificate to users, and swap in a
	// new one without affecting the old certificate. If we modify the
	// certificate in-place, we might as well just remove the [atomic.Pointer]
	// and pray we don't update the certificate while it's being used.
	clone := *current
	// Drop the old staple so the copy starts out unstapled.
	clone.OCSPStaple = nil

	// Count only real stapling attempts so the total is comparable with the
	// error counter.
	w.stapleTotalCounter.Add(ctx, 1)

	// Staple the cloned certificate.
	if err := w.staple(ctx, &clone); err != nil {
		w.logger.LogAttrs(ctx, slog.LevelWarn, "failed to re-staple certificate", slog.Any("err", err))

		// Increment the staple error counter.
		w.stapleErrorCounter.Add(ctx, 1)

		// Keep serving the existing certificate and retry later.
		return retryInterval
	}

	// Store the newly cloned and stapled certificate.
	w.cert.Store(&clone)

	// Wait until the next refresh time. If it is not in the future (for
	// example because no OCSP response was stored), fall back to the retry
	// interval instead of returning a non-positive duration.
	next := time.Until(ocsp.RefreshTime(w.ocsp))
	if next <= 0 {
		return retryInterval
	}
	return next
}

// handleEvent handles incoming fsnotify events to detect when we need to reload
// the watched certificate files.
func (w *Watcher) handleEvent(ctx context.Context, event fsnotify.Event) {
Expand Down
Loading

0 comments on commit b669a7f

Please sign in to comment.