credential_cache.go (7945B)
1 package aws 2 3 import ( 4 "context" 5 "fmt" 6 "sync/atomic" 7 "time" 8 9 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" 10 "github.com/aws/aws-sdk-go-v2/internal/sync/singleflight" 11 ) 12 13 // CredentialsCacheOptions are the options 14 type CredentialsCacheOptions struct { 15 16 // ExpiryWindow will allow the credentials to trigger refreshing prior to 17 // the credentials actually expiring. This is beneficial so race conditions 18 // with expiring credentials do not cause request to fail unexpectedly 19 // due to ExpiredTokenException exceptions. 20 // 21 // An ExpiryWindow of 10s would cause calls to IsExpired() to return true 22 // 10 seconds before the credentials are actually expired. This can cause an 23 // increased number of requests to refresh the credentials to occur. 24 // 25 // If ExpiryWindow is 0 or less it will be ignored. 26 ExpiryWindow time.Duration 27 28 // ExpiryWindowJitterFrac provides a mechanism for randomizing the 29 // expiration of credentials within the configured ExpiryWindow by a random 30 // percentage. Valid values are between 0.0 and 1.0. 31 // 32 // As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac 33 // is 0.5 then credentials will be set to expire between 30 to 60 seconds 34 // prior to their actual expiration time. 35 // 36 // If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored. 37 // If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window. 38 // If ExpiryWindowJitterFrac < 0 the value will be treated as 0. 39 // If ExpiryWindowJitterFrac > 1 the value will be treated as 1. 40 ExpiryWindowJitterFrac float64 41 } 42 43 // CredentialsCache provides caching and concurrency safe credentials retrieval 44 // via the provider's retrieve method. 45 // 46 // CredentialsCache will look for optional interfaces on the Provider to adjust 47 // how the credential cache handles credentials caching. 48 // 49 // - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle 50 // credential refresh failures. This could return an updated Credentials 51 // value, or attempt another means of retrieving credentials. 52 // 53 // - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how 54 // credentials Expires is modified. This could modify how the Credentials 55 // Expires is adjusted based on the CredentialsCache ExpiryWindow option. 56 // Such as providing a floor not to reduce the Expires below. 57 type CredentialsCache struct { 58 provider CredentialsProvider 59 60 options CredentialsCacheOptions 61 creds atomic.Value 62 sf singleflight.Group 63 } 64 65 // NewCredentialsCache returns a CredentialsCache that wraps provider. Provider 66 // is expected to not be nil. A variadic list of one or more functions can be 67 // provided to modify the CredentialsCache configuration. This allows for 68 // configuration of credential expiry window and jitter. 69 func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache { 70 options := CredentialsCacheOptions{} 71 72 for _, fn := range optFns { 73 fn(&options) 74 } 75 76 if options.ExpiryWindow < 0 { 77 options.ExpiryWindow = 0 78 } 79 80 if options.ExpiryWindowJitterFrac < 0 { 81 options.ExpiryWindowJitterFrac = 0 82 } else if options.ExpiryWindowJitterFrac > 1 { 83 options.ExpiryWindowJitterFrac = 1 84 } 85 86 return &CredentialsCache{ 87 provider: provider, 88 options: options, 89 } 90 } 91 92 // Retrieve returns the credentials. If the credentials have already been 93 // retrieved, and not expired the cached credentials will be returned. If the 94 // credentials have not been retrieved yet, or expired the provider's Retrieve 95 // method will be called. 96 // 97 // Returns and error if the provider's retrieve method returns an error. 98 func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) { 99 if creds, ok := p.getCreds(); ok && !creds.Expired() { 100 return creds, nil 101 } 102 103 resCh := p.sf.DoChan("", func() (interface{}, error) { 104 return p.singleRetrieve(&suppressedContext{ctx}) 105 }) 106 select { 107 case res := <-resCh: 108 return res.Val.(Credentials), res.Err 109 case <-ctx.Done(): 110 return Credentials{}, &RequestCanceledError{Err: ctx.Err()} 111 } 112 } 113 114 func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) { 115 currCreds, ok := p.getCreds() 116 if ok && !currCreds.Expired() { 117 return currCreds, nil 118 } 119 120 newCreds, err := p.provider.Retrieve(ctx) 121 if err != nil { 122 handleFailToRefresh := defaultHandleFailToRefresh 123 if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok { 124 handleFailToRefresh = cs.HandleFailToRefresh 125 } 126 newCreds, err = handleFailToRefresh(ctx, currCreds, err) 127 if err != nil { 128 return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err) 129 } 130 } 131 132 if newCreds.CanExpire && p.options.ExpiryWindow > 0 { 133 adjustExpiresBy := defaultAdjustExpiresBy 134 if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok { 135 adjustExpiresBy = cs.AdjustExpiresBy 136 } 137 138 randFloat64, err := sdkrand.CryptoRandFloat64() 139 if err != nil { 140 return Credentials{}, fmt.Errorf("failed to get random provider, %w", err) 141 } 142 143 var jitter time.Duration 144 if p.options.ExpiryWindowJitterFrac > 0 { 145 jitter = time.Duration(randFloat64 * 146 p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow)) 147 } 148 149 newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter)) 150 if err != nil { 151 return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err) 152 } 153 } 154 155 p.creds.Store(&newCreds) 156 return newCreds, nil 157 } 158 159 // getCreds returns the currently stored credentials and true. Returning false 160 // if no credentials were stored. 161 func (p *CredentialsCache) getCreds() (Credentials, bool) { 162 v := p.creds.Load() 163 if v == nil { 164 return Credentials{}, false 165 } 166 167 c := v.(*Credentials) 168 if c == nil || !c.HasKeys() { 169 return Credentials{}, false 170 } 171 172 return *c, true 173 } 174 175 // Invalidate will invalidate the cached credentials. The next call to Retrieve 176 // will cause the provider's Retrieve method to be called. 177 func (p *CredentialsCache) Invalidate() { 178 p.creds.Store((*Credentials)(nil)) 179 } 180 181 // IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache 182 // matches the target provider type. 183 func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool { 184 return IsCredentialsProvider(p.provider, target) 185 } 186 187 // HandleFailRefreshCredentialsCacheStrategy is an interface for 188 // CredentialsCache to allow CredentialsProvider how failed to refresh 189 // credentials is handled. 190 type HandleFailRefreshCredentialsCacheStrategy interface { 191 // Given the previously cached Credentials, if any, and refresh error, may 192 // returns new or modified set of Credentials, or error. 193 // 194 // Credential caches may use default implementation if nil. 195 HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error) 196 } 197 198 // defaultHandleFailToRefresh returns the passed in error. 199 func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) { 200 return Credentials{}, err 201 } 202 203 // AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache 204 // to allow CredentialsProvider to intercept adjustments to Credentials expiry 205 // based on expectations and use cases of CredentialsProvider. 206 // 207 // Credential caches may use default implementation if nil. 208 type AdjustExpiresByCredentialsCacheStrategy interface { 209 // Given a Credentials as input, applying any mutations and 210 // returning the potentially updated Credentials, or error. 211 AdjustExpiresBy(Credentials, time.Duration) (Credentials, error) 212 } 213 214 // defaultAdjustExpiresBy adds the duration to the passed in credentials Expires, 215 // and returns the updated credentials value. If Credentials value's CanExpire 216 // is false, the passed in credentials are returned unchanged. 217 func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) { 218 if !creds.CanExpire { 219 return creds, nil 220 } 221 222 creds.Expires = creds.Expires.Add(dur) 223 return creds, nil 224 }