token_provider.go (7163B)
1 package imds 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "github.com/aws/aws-sdk-go-v2/aws" 8 "github.com/aws/smithy-go" 9 "github.com/aws/smithy-go/logging" 10 "net/http" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "github.com/aws/smithy-go/middleware" 16 smithyhttp "github.com/aws/smithy-go/transport/http" 17 ) 18 19 const ( 20 // Headers for Token and TTL 21 tokenHeader = "x-aws-ec2-metadata-token" 22 defaultTokenTTL = 5 * time.Minute 23 ) 24 25 type tokenProvider struct { 26 client *Client 27 tokenTTL time.Duration 28 29 token *apiToken 30 tokenMux sync.RWMutex 31 32 disabled uint32 // Atomic updated 33 } 34 35 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider { 36 return &tokenProvider{ 37 client: client, 38 tokenTTL: ttl, 39 } 40 } 41 42 // apiToken provides the API token used by all operation calls for th EC2 43 // Instance metadata service. 44 type apiToken struct { 45 token string 46 expires time.Time 47 } 48 49 var timeNow = time.Now 50 51 // Expired returns if the token is expired. 52 func (t *apiToken) Expired() bool { 53 // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry 54 // time is always based on reported wall-clock time. 55 return timeNow().Round(0).After(t.expires) 56 } 57 58 func (t *tokenProvider) ID() string { return "APITokenProvider" } 59 60 // HandleFinalize is the finalize stack middleware, that if the token provider is 61 // enabled, will attempt to add the cached API token to the request. If the API 62 // token is not cached, it will be retrieved in a separate API call, getToken. 63 // 64 // For retry attempts, handler must be added after attempt retryer. 65 // 66 // If request for getToken fails the token provider may be disabled from future 67 // requests, depending on the response status code. 68 func (t *tokenProvider) HandleFinalize( 69 ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler, 70 ) ( 71 out middleware.FinalizeOutput, metadata middleware.Metadata, err error, 72 ) { 73 if t.fallbackEnabled() && !t.enabled() { 74 // short-circuits to insecure data flow if token provider is disabled. 75 return next.HandleFinalize(ctx, input) 76 } 77 78 req, ok := input.Request.(*smithyhttp.Request) 79 if !ok { 80 return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request) 81 } 82 83 tok, err := t.getToken(ctx) 84 if err != nil { 85 // If the error allows the token to downgrade to insecure flow allow that. 86 var bypassErr *bypassTokenRetrievalError 87 if errors.As(err, &bypassErr) { 88 return next.HandleFinalize(ctx, input) 89 } 90 91 return out, metadata, fmt.Errorf("failed to get API token, %w", err) 92 } 93 94 req.Header.Set(tokenHeader, tok.token) 95 96 return next.HandleFinalize(ctx, input) 97 } 98 99 // HandleDeserialize is the deserialize stack middleware for determining if the 100 // operation the token provider is decorating failed because of a 401 101 // unauthorized status code. If the operation failed for that reason the token 102 // provider needs to be re-enabled so that it can start adding the API token to 103 // operation calls. 104 func (t *tokenProvider) HandleDeserialize( 105 ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler, 106 ) ( 107 out middleware.DeserializeOutput, metadata middleware.Metadata, err error, 108 ) { 109 out, metadata, err = next.HandleDeserialize(ctx, input) 110 if err == nil { 111 return out, metadata, err 112 } 113 114 resp, ok := out.RawResponse.(*smithyhttp.Response) 115 if !ok { 116 return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse) 117 } 118 119 if resp.StatusCode == http.StatusUnauthorized { // unauthorized 120 t.enable() 121 err = &retryableError{Err: err, isRetryable: true} 122 } 123 124 return out, metadata, err 125 } 126 127 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) { 128 if t.fallbackEnabled() && !t.enabled() { 129 return nil, &bypassTokenRetrievalError{ 130 Err: fmt.Errorf("cannot get API token, provider disabled"), 131 } 132 } 133 134 t.tokenMux.RLock() 135 tok = t.token 136 t.tokenMux.RUnlock() 137 138 if tok != nil && !tok.Expired() { 139 return tok, nil 140 } 141 142 tok, err = t.updateToken(ctx) 143 if err != nil { 144 return nil, err 145 } 146 147 return tok, nil 148 } 149 150 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) { 151 t.tokenMux.Lock() 152 defer t.tokenMux.Unlock() 153 154 // Prevent multiple requests to update retrieving the token. 155 if t.token != nil && !t.token.Expired() { 156 tok := t.token 157 return tok, nil 158 } 159 160 result, err := t.client.getToken(ctx, &getTokenInput{ 161 TokenTTL: t.tokenTTL, 162 }) 163 if err != nil { 164 var statusErr interface{ HTTPStatusCode() int } 165 if errors.As(err, &statusErr) { 166 switch statusErr.HTTPStatusCode() { 167 // Disable future get token if failed because of 403, 404, or 405 168 case http.StatusForbidden, 169 http.StatusNotFound, 170 http.StatusMethodNotAllowed: 171 172 if t.fallbackEnabled() { 173 logger := middleware.GetLogger(ctx) 174 logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err) 175 t.disable() 176 } 177 178 // 400 errors are terminal, and need to be upstreamed 179 case http.StatusBadRequest: 180 return nil, err 181 } 182 } 183 184 // Disable if request send failed or timed out getting response 185 var re *smithyhttp.RequestSendError 186 var ce *smithy.CanceledError 187 if errors.As(err, &re) || errors.As(err, &ce) { 188 atomic.StoreUint32(&t.disabled, 1) 189 } 190 191 if !t.fallbackEnabled() { 192 // NOTE: getToken() is an implementation detail of some outer operation 193 // (e.g. GetMetadata). It has its own retries that have already been exhausted. 194 // Mark the underlying error as a terminal error. 195 err = &retryableError{Err: err, isRetryable: false} 196 return nil, err 197 } 198 199 // Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request 200 // and allow the request to proceed. Future requests _may_ re-attempt fetching a 201 // token if not disabled. 202 return nil, &bypassTokenRetrievalError{Err: err} 203 } 204 205 tok := &apiToken{ 206 token: result.Token, 207 expires: timeNow().Add(result.TokenTTL), 208 } 209 t.token = tok 210 211 return tok, nil 212 } 213 214 // enabled returns if the token provider is current enabled or not. 215 func (t *tokenProvider) enabled() bool { 216 return atomic.LoadUint32(&t.disabled) == 0 217 } 218 219 // fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise 220 func (t *tokenProvider) fallbackEnabled() bool { 221 switch t.client.options.EnableFallback { 222 case aws.FalseTernary: 223 return false 224 default: 225 return true 226 } 227 } 228 229 // disable disables the token provider and it will no longer attempt to inject 230 // the token, nor request updates. 231 func (t *tokenProvider) disable() { 232 atomic.StoreUint32(&t.disabled, 1) 233 } 234 235 // enable enables the token provide to start refreshing tokens, and adding them 236 // to the pending request. 237 func (t *tokenProvider) enable() { 238 t.tokenMux.Lock() 239 t.token = nil 240 t.tokenMux.Unlock() 241 atomic.StoreUint32(&t.disabled, 0) 242 } 243 244 type bypassTokenRetrievalError struct { 245 Err error 246 } 247 248 func (e *bypassTokenRetrievalError) Error() string { 249 return fmt.Sprintf("bypass token retrieval, %v", e.Err) 250 } 251 252 func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err } 253 254 type retryableError struct { 255 Err error 256 isRetryable bool 257 } 258 259 func (e *retryableError) RetryableError() bool { return e.isRetryable } 260 261 func (e *retryableError) Error() string { return e.Err.Error() }