src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

credential_cache.go (7945B)


      1 package aws
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"sync/atomic"
      7 	"time"
      8 
      9 	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
     10 	"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
     11 )
     12 
     13 // CredentialsCacheOptions are the options
     14 type CredentialsCacheOptions struct {
     15 
     16 	// ExpiryWindow will allow the credentials to trigger refreshing prior to
     17 	// the credentials actually expiring. This is beneficial so race conditions
     18 	// with expiring credentials do not cause request to fail unexpectedly
     19 	// due to ExpiredTokenException exceptions.
     20 	//
     21 	// An ExpiryWindow of 10s would cause calls to IsExpired() to return true
     22 	// 10 seconds before the credentials are actually expired. This can cause an
     23 	// increased number of requests to refresh the credentials to occur.
     24 	//
     25 	// If ExpiryWindow is 0 or less it will be ignored.
     26 	ExpiryWindow time.Duration
     27 
     28 	// ExpiryWindowJitterFrac provides a mechanism for randomizing the
     29 	// expiration of credentials within the configured ExpiryWindow by a random
     30 	// percentage. Valid values are between 0.0 and 1.0.
     31 	//
     32 	// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
     33 	// is 0.5 then credentials will be set to expire between 30 to 60 seconds
     34 	// prior to their actual expiration time.
     35 	//
     36 	// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
     37 	// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
     38 	// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
     39 	// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
     40 	ExpiryWindowJitterFrac float64
     41 }
     42 
     43 // CredentialsCache provides caching and concurrency safe credentials retrieval
     44 // via the provider's retrieve method.
     45 //
     46 // CredentialsCache will look for optional interfaces on the Provider to adjust
     47 // how the credential cache handles credentials caching.
     48 //
     49 //   - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
     50 //     credential refresh failures. This could return an updated Credentials
     51 //     value, or attempt another means of retrieving credentials.
     52 //
     53 //   - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
     54 //     credentials Expires is modified. This could modify how the Credentials
     55 //     Expires is adjusted based on the CredentialsCache ExpiryWindow option.
     56 //     Such as providing a floor not to reduce the Expires below.
     57 type CredentialsCache struct {
     58 	provider CredentialsProvider
     59 
     60 	options CredentialsCacheOptions
     61 	creds   atomic.Value
     62 	sf      singleflight.Group
     63 }
     64 
     65 // NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
     66 // is expected to not be nil. A variadic list of one or more functions can be
     67 // provided to modify the CredentialsCache configuration. This allows for
     68 // configuration of credential expiry window and jitter.
     69 func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
     70 	options := CredentialsCacheOptions{}
     71 
     72 	for _, fn := range optFns {
     73 		fn(&options)
     74 	}
     75 
     76 	if options.ExpiryWindow < 0 {
     77 		options.ExpiryWindow = 0
     78 	}
     79 
     80 	if options.ExpiryWindowJitterFrac < 0 {
     81 		options.ExpiryWindowJitterFrac = 0
     82 	} else if options.ExpiryWindowJitterFrac > 1 {
     83 		options.ExpiryWindowJitterFrac = 1
     84 	}
     85 
     86 	return &CredentialsCache{
     87 		provider: provider,
     88 		options:  options,
     89 	}
     90 }
     91 
     92 // Retrieve returns the credentials. If the credentials have already been
     93 // retrieved, and not expired the cached credentials will be returned. If the
     94 // credentials have not been retrieved yet, or expired the provider's Retrieve
     95 // method will be called.
     96 //
     97 // Returns and error if the provider's retrieve method returns an error.
     98 func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
     99 	if creds, ok := p.getCreds(); ok && !creds.Expired() {
    100 		return creds, nil
    101 	}
    102 
    103 	resCh := p.sf.DoChan("", func() (interface{}, error) {
    104 		return p.singleRetrieve(&suppressedContext{ctx})
    105 	})
    106 	select {
    107 	case res := <-resCh:
    108 		return res.Val.(Credentials), res.Err
    109 	case <-ctx.Done():
    110 		return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
    111 	}
    112 }
    113 
    114 func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
    115 	currCreds, ok := p.getCreds()
    116 	if ok && !currCreds.Expired() {
    117 		return currCreds, nil
    118 	}
    119 
    120 	newCreds, err := p.provider.Retrieve(ctx)
    121 	if err != nil {
    122 		handleFailToRefresh := defaultHandleFailToRefresh
    123 		if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
    124 			handleFailToRefresh = cs.HandleFailToRefresh
    125 		}
    126 		newCreds, err = handleFailToRefresh(ctx, currCreds, err)
    127 		if err != nil {
    128 			return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
    129 		}
    130 	}
    131 
    132 	if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
    133 		adjustExpiresBy := defaultAdjustExpiresBy
    134 		if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
    135 			adjustExpiresBy = cs.AdjustExpiresBy
    136 		}
    137 
    138 		randFloat64, err := sdkrand.CryptoRandFloat64()
    139 		if err != nil {
    140 			return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
    141 		}
    142 
    143 		var jitter time.Duration
    144 		if p.options.ExpiryWindowJitterFrac > 0 {
    145 			jitter = time.Duration(randFloat64 *
    146 				p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
    147 		}
    148 
    149 		newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
    150 		if err != nil {
    151 			return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
    152 		}
    153 	}
    154 
    155 	p.creds.Store(&newCreds)
    156 	return newCreds, nil
    157 }
    158 
    159 // getCreds returns the currently stored credentials and true. Returning false
    160 // if no credentials were stored.
    161 func (p *CredentialsCache) getCreds() (Credentials, bool) {
    162 	v := p.creds.Load()
    163 	if v == nil {
    164 		return Credentials{}, false
    165 	}
    166 
    167 	c := v.(*Credentials)
    168 	if c == nil || !c.HasKeys() {
    169 		return Credentials{}, false
    170 	}
    171 
    172 	return *c, true
    173 }
    174 
    175 // Invalidate will invalidate the cached credentials. The next call to Retrieve
    176 // will cause the provider's Retrieve method to be called.
    177 func (p *CredentialsCache) Invalidate() {
    178 	p.creds.Store((*Credentials)(nil))
    179 }
    180 
    181 // IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
    182 // matches the target provider type.
    183 func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
    184 	return IsCredentialsProvider(p.provider, target)
    185 }
    186 
    187 // HandleFailRefreshCredentialsCacheStrategy is an interface for
    188 // CredentialsCache to allow CredentialsProvider  how failed to refresh
    189 // credentials is handled.
    190 type HandleFailRefreshCredentialsCacheStrategy interface {
    191 	// Given the previously cached Credentials, if any, and refresh error, may
    192 	// returns new or modified set of Credentials, or error.
    193 	//
    194 	// Credential caches may use default implementation if nil.
    195 	HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
    196 }
    197 
    198 // defaultHandleFailToRefresh returns the passed in error.
    199 func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
    200 	return Credentials{}, err
    201 }
    202 
    203 // AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
    204 // to allow CredentialsProvider to intercept adjustments to Credentials expiry
    205 // based on expectations and use cases of CredentialsProvider.
    206 //
    207 // Credential caches may use default implementation if nil.
    208 type AdjustExpiresByCredentialsCacheStrategy interface {
    209 	// Given a Credentials as input, applying any mutations and
    210 	// returning the potentially updated Credentials, or error.
    211 	AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
    212 }
    213 
    214 // defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
    215 // and returns the updated credentials value. If Credentials value's CanExpire
    216 // is false, the passed in credentials are returned unchanged.
    217 func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
    218 	if !creds.CanExpire {
    219 		return creds, nil
    220 	}
    221 
    222 	creds.Expires = creds.Expires.Add(dur)
    223 	return creds, nil
    224 }