src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

token_cache.go (7234B)


      1 package bearer
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"sync/atomic"
      7 	"time"
      8 
      9 	smithycontext "github.com/aws/smithy-go/context"
     10 	"github.com/aws/smithy-go/internal/sync/singleflight"
     11 )
     12 
     // timeNow is the clock this package reads the current time from. It is a
     // package-level variable so that unit tests can replace it with a
     // deterministic clock.
     var timeNow func() time.Time = time.Now
     15 
     // TokenCacheOptions holds the optional settings that tune how a TokenCache
     // refreshes the token it caches.
     type TokenCacheOptions struct {
     	// RefreshBeforeExpires is how long before the cached token expires that
     	// the TokenCache starts trying to replace it. When DisableAsyncRefresh
     	// is true, a RetrieveBearerToken call made inside this window blocks on
     	// the refresh.
     	//
     	// Asynchronous refreshes are deduplicated, so at most one is in flight
     	// at a time. If the token expires while an asynchronous refresh is
     	// still in flight, the next RetrieveBearerToken call blocks until that
     	// refresh returns.
     	RefreshBeforeExpires time.Duration

     	// RetrieveBearerTokenTimeout bounds how long the wrapped
     	// TokenProvider's RetrieveBearerToken call may run before it is
     	// canceled. The default, 0, means no timeout.
     	//
     	// Without a timeout it is possible for the wrapped provider's call to
     	// block forever, which prevents later TokenCache attempts to refresh
     	// the token.
     	//
     	// When the timeout is reached, every pending deduplicated call to the
     	// TokenCache's RetrieveBearerToken fails with an error.
     	RetrieveBearerTokenTimeout time.Duration

     	// AsyncRefreshMinimumDelay is the minimum time that must pass after an
     	// asynchronous refresh attempt before another one is started. While
     	// the delay has not elapsed, RetrieveBearerToken returns the current
     	// cached token, provided it has not expired. Only an asynchronous
     	// refresh that returned an error records an attempt time; a successful
     	// one clears it.
     	//
     	// The asynchronous retrieve is not a periodic task. It only runs when
     	// the TokenCache's RetrieveBearerToken method is called while the
     	// cached token is unexpired but inside the RefreshBeforeExpires
     	// window, and it is deduplicated across those calls.
     	//
     	// If 0 (the default), no minimum delay is enforced between
     	// asynchronous refresh attempts.
     	//
     	// This option is ignored when DisableAsyncRefresh is true.
     	AsyncRefreshMinimumDelay time.Duration

     	// DisableAsyncRefresh turns off background refreshing. When false, a
     	// token inside the RefreshBeforeExpires window is refreshed
     	// asynchronously while the cached token is handed back; when true, the
     	// refresh blocks the caller instead.
     	//
     	// The first call to RetrieveBearerToken always blocks, because there
     	// is no cached token yet.
     	DisableAsyncRefresh bool
     }
     65 
     66 // TokenCache provides an utility to cache Bearer Authentication tokens from a
     67 // wrapped TokenProvider. The TokenCache can be has options to configure the
     68 // cache's early and asynchronous refresh of the token.
     69 type TokenCache struct {
     70 	options  TokenCacheOptions
     71 	provider TokenProvider
     72 
     73 	cachedToken            atomic.Value
     74 	lastRefreshAttemptTime atomic.Value
     75 	sfGroup                singleflight.Group
     76 }
     77 
     78 // NewTokenCache returns a initialized TokenCache that implements the
     79 // TokenProvider interface. Wrapping the provider passed in. Also taking a set
     80 // of optional functional option parameters to configure the token cache.
     81 func NewTokenCache(provider TokenProvider, optFns ...func(*TokenCacheOptions)) *TokenCache {
     82 	var options TokenCacheOptions
     83 	for _, fn := range optFns {
     84 		fn(&options)
     85 	}
     86 
     87 	return &TokenCache{
     88 		options:  options,
     89 		provider: provider,
     90 	}
     91 }
     92 
     93 // RetrieveBearerToken returns the token if it could be obtained, or error if a
     94 // valid token could not be retrieved.
     95 //
     96 // The passed in Context's cancel/deadline/timeout will impacting only this
     97 // individual retrieve call and not any other already queued up calls. This
     98 // means underlying provider's RetrieveBearerToken calls could block for ever,
     99 // and not be canceled with the Context. Set RetrieveBearerTokenTimeout to
    100 // provide a timeout, preventing the underlying TokenProvider blocking forever.
    101 //
    102 // By default, if the passed in Context is canceled, all of its values will be
    103 // considered expired. The wrapped TokenProvider will not be able to lookup the
    104 // values from the Context once it is expired. This is done to protect against
    105 // expired values no longer being valid. To disable this behavior, use
    106 // smithy-go's context.WithPreserveExpiredValues to add a value to the Context
    107 // before calling RetrieveBearerToken to enable support for expired values.
    108 //
    109 // Without RetrieveBearerTokenTimeout there is the potential for a underlying
    110 // Provider's RetrieveBearerToken call to sit forever. Blocking in subsequent
    111 // attempts at refreshing the token.
    112 func (p *TokenCache) RetrieveBearerToken(ctx context.Context) (Token, error) {
    113 	cachedToken, ok := p.getCachedToken()
    114 	if !ok || cachedToken.Expired(timeNow()) {
    115 		return p.refreshBearerToken(ctx)
    116 	}
    117 
    118 	// Check if the token should be refreshed before it expires.
    119 	refreshToken := cachedToken.Expired(timeNow().Add(p.options.RefreshBeforeExpires))
    120 	if !refreshToken {
    121 		return cachedToken, nil
    122 	}
    123 
    124 	if p.options.DisableAsyncRefresh {
    125 		return p.refreshBearerToken(ctx)
    126 	}
    127 
    128 	p.tryAsyncRefresh(ctx)
    129 
    130 	return cachedToken, nil
    131 }
    132 
    133 // tryAsyncRefresh attempts to asynchronously refresh the token returning the
    134 // already cached token. If it AsyncRefreshMinimumDelay option is not zero, and
    135 // the duration since the last refresh is less than that value, nothing will be
    136 // done.
    137 func (p *TokenCache) tryAsyncRefresh(ctx context.Context) {
    138 	if p.options.AsyncRefreshMinimumDelay != 0 {
    139 		var lastRefreshAttempt time.Time
    140 		if v := p.lastRefreshAttemptTime.Load(); v != nil {
    141 			lastRefreshAttempt = v.(time.Time)
    142 		}
    143 
    144 		if timeNow().Before(lastRefreshAttempt.Add(p.options.AsyncRefreshMinimumDelay)) {
    145 			return
    146 		}
    147 	}
    148 
    149 	// Ignore the returned channel so this won't be blocking, and limit the
    150 	// number of additional goroutines created.
    151 	p.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
    152 		res, err := p.refreshBearerToken(ctx)
    153 		if p.options.AsyncRefreshMinimumDelay != 0 {
    154 			var refreshAttempt time.Time
    155 			if err != nil {
    156 				refreshAttempt = timeNow()
    157 			}
    158 			p.lastRefreshAttemptTime.Store(refreshAttempt)
    159 		}
    160 
    161 		return res, err
    162 	})
    163 }
    164 
    165 func (p *TokenCache) refreshBearerToken(ctx context.Context) (Token, error) {
    166 	resCh := p.sfGroup.DoChan("refresh-token", func() (interface{}, error) {
    167 		ctx := smithycontext.WithSuppressCancel(ctx)
    168 		if v := p.options.RetrieveBearerTokenTimeout; v != 0 {
    169 			var cancel func()
    170 			ctx, cancel = context.WithTimeout(ctx, v)
    171 			defer cancel()
    172 		}
    173 		return p.singleRetrieve(ctx)
    174 	})
    175 
    176 	select {
    177 	case res := <-resCh:
    178 		return res.Val.(Token), res.Err
    179 	case <-ctx.Done():
    180 		return Token{}, fmt.Errorf("retrieve bearer token canceled, %w", ctx.Err())
    181 	}
    182 }
    183 
    184 func (p *TokenCache) singleRetrieve(ctx context.Context) (interface{}, error) {
    185 	token, err := p.provider.RetrieveBearerToken(ctx)
    186 	if err != nil {
    187 		return Token{}, fmt.Errorf("failed to retrieve bearer token, %w", err)
    188 	}
    189 
    190 	p.cachedToken.Store(&token)
    191 	return token, nil
    192 }
    193 
    194 // getCachedToken returns the currently cached token and true if found. Returns
    195 // false if no token is cached.
    196 func (p *TokenCache) getCachedToken() (Token, bool) {
    197 	v := p.cachedToken.Load()
    198 	if v == nil {
    199 		return Token{}, false
    200 	}
    201 
    202 	t := v.(*Token)
    203 	if t == nil || t.Value == "" {
    204 		return Token{}, false
    205 	}
    206 
    207 	return *t, true
    208 }