token_provider.go (6310B)
1 package imds 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/http" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 smithy "github.com/aws/smithy-go" 13 "github.com/aws/smithy-go/middleware" 14 smithyhttp "github.com/aws/smithy-go/transport/http" 15 ) 16 17 const ( 18 // Headers for Token and TTL 19 tokenHeader = "x-aws-ec2-metadata-token" 20 defaultTokenTTL = 5 * time.Minute 21 ) 22 23 type tokenProvider struct { 24 client *Client 25 tokenTTL time.Duration 26 27 token *apiToken 28 tokenMux sync.RWMutex 29 30 disabled uint32 // Atomic updated 31 } 32 33 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider { 34 return &tokenProvider{ 35 client: client, 36 tokenTTL: ttl, 37 } 38 } 39 40 // apiToken provides the API token used by all operation calls for th EC2 41 // Instance metadata service. 42 type apiToken struct { 43 token string 44 expires time.Time 45 } 46 47 var timeNow = time.Now 48 49 // Expired returns if the token is expired. 50 func (t *apiToken) Expired() bool { 51 // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry 52 // time is always based on reported wall-clock time. 53 return timeNow().Round(0).After(t.expires) 54 } 55 56 func (t *tokenProvider) ID() string { return "APITokenProvider" } 57 58 // HandleFinalize is the finalize stack middleware, that if the token provider is 59 // enabled, will attempt to add the cached API token to the request. If the API 60 // token is not cached, it will be retrieved in a separate API call, getToken. 61 // 62 // For retry attempts, handler must be added after attempt retryer. 63 // 64 // If request for getToken fails the token provider may be disabled from future 65 // requests, depending on the response status code. 66 func (t *tokenProvider) HandleFinalize( 67 ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler, 68 ) ( 69 out middleware.FinalizeOutput, metadata middleware.Metadata, err error, 70 ) { 71 if !t.enabled() { 72 // short-circuits to insecure data flow if token provider is disabled. 73 return next.HandleFinalize(ctx, input) 74 } 75 76 req, ok := input.Request.(*smithyhttp.Request) 77 if !ok { 78 return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request) 79 } 80 81 tok, err := t.getToken(ctx) 82 if err != nil { 83 // If the error allows the token to downgrade to insecure flow allow that. 84 var bypassErr *bypassTokenRetrievalError 85 if errors.As(err, &bypassErr) { 86 return next.HandleFinalize(ctx, input) 87 } 88 89 return out, metadata, fmt.Errorf("failed to get API token, %w", err) 90 } 91 92 req.Header.Set(tokenHeader, tok.token) 93 94 return next.HandleFinalize(ctx, input) 95 } 96 97 // HandleDeserialize is the deserialize stack middleware for determining if the 98 // operation the token provider is decorating failed because of a 401 99 // unauthorized status code. If the operation failed for that reason the token 100 // provider needs to be re-enabled so that it can start adding the API token to 101 // operation calls. 102 func (t *tokenProvider) HandleDeserialize( 103 ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler, 104 ) ( 105 out middleware.DeserializeOutput, metadata middleware.Metadata, err error, 106 ) { 107 out, metadata, err = next.HandleDeserialize(ctx, input) 108 if err == nil { 109 return out, metadata, err 110 } 111 112 resp, ok := out.RawResponse.(*smithyhttp.Response) 113 if !ok { 114 return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse) 115 } 116 117 if resp.StatusCode == http.StatusUnauthorized { // unauthorized 118 err = &retryableError{Err: err} 119 t.enable() 120 } 121 122 return out, metadata, err 123 } 124 125 type retryableError struct { 126 Err error 127 } 128 129 func (*retryableError) RetryableError() bool { return true } 130 131 func (e *retryableError) Error() string { return e.Err.Error() } 132 133 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) { 134 if !t.enabled() { 135 return nil, &bypassTokenRetrievalError{ 136 Err: fmt.Errorf("cannot get API token, provider disabled"), 137 } 138 } 139 140 t.tokenMux.RLock() 141 tok = t.token 142 t.tokenMux.RUnlock() 143 144 if tok != nil && !tok.Expired() { 145 return tok, nil 146 } 147 148 tok, err = t.updateToken(ctx) 149 if err != nil { 150 return nil, fmt.Errorf("cannot get API token, %w", err) 151 } 152 153 return tok, nil 154 } 155 156 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) { 157 t.tokenMux.Lock() 158 defer t.tokenMux.Unlock() 159 160 // Prevent multiple requests to update retrieving the token. 161 if t.token != nil && !t.token.Expired() { 162 tok := t.token 163 return tok, nil 164 } 165 166 result, err := t.client.getToken(ctx, &getTokenInput{ 167 TokenTTL: t.tokenTTL, 168 }) 169 if err != nil { 170 // change the disabled flag on token provider to true, when error is request timeout error. 171 var statusErr interface{ HTTPStatusCode() int } 172 if errors.As(err, &statusErr) { 173 switch statusErr.HTTPStatusCode() { 174 175 // Disable get token if failed because of 403, 404, or 405 176 case http.StatusForbidden, 177 http.StatusNotFound, 178 http.StatusMethodNotAllowed: 179 180 t.disable() 181 182 // 400 errors are terminal, and need to be upstreamed 183 case http.StatusBadRequest: 184 return nil, err 185 } 186 } 187 188 // Disable if request send failed or timed out getting response 189 var re *smithyhttp.RequestSendError 190 var ce *smithy.CanceledError 191 if errors.As(err, &re) || errors.As(err, &ce) { 192 atomic.StoreUint32(&t.disabled, 1) 193 } 194 195 // Token couldn't be retrieved, but bypass this, and allow the 196 // request to continue. 197 return nil, &bypassTokenRetrievalError{Err: err} 198 } 199 200 tok := &apiToken{ 201 token: result.Token, 202 expires: timeNow().Add(result.TokenTTL), 203 } 204 t.token = tok 205 206 return tok, nil 207 } 208 209 type bypassTokenRetrievalError struct { 210 Err error 211 } 212 213 func (e *bypassTokenRetrievalError) Error() string { 214 return fmt.Sprintf("bypass token retrieval, %v", e.Err) 215 } 216 217 func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err } 218 219 // enabled returns if the token provider is current enabled or not. 220 func (t *tokenProvider) enabled() bool { 221 return atomic.LoadUint32(&t.disabled) == 0 222 } 223 224 // disable disables the token provider and it will no longer attempt to inject 225 // the token, nor request updates. 226 func (t *tokenProvider) disable() { 227 atomic.StoreUint32(&t.disabled, 1) 228 } 229 230 // enable enables the token provide to start refreshing tokens, and adding them 231 // to the pending request. 232 func (t *tokenProvider) enable() { 233 t.tokenMux.Lock() 234 t.token = nil 235 t.tokenMux.Unlock() 236 atomic.StoreUint32(&t.disabled, 0) 237 }