src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

token_provider.go (7163B)


      1 package imds
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 	"github.com/aws/aws-sdk-go-v2/aws"
      8 	"github.com/aws/smithy-go"
      9 	"github.com/aws/smithy-go/logging"
     10 	"net/http"
     11 	"sync"
     12 	"sync/atomic"
     13 	"time"
     14 
     15 	"github.com/aws/smithy-go/middleware"
     16 	smithyhttp "github.com/aws/smithy-go/transport/http"
     17 )
     18 
     19 const (
     20 	// Headers for Token and TTL
     21 	tokenHeader     = "x-aws-ec2-metadata-token"
     22 	defaultTokenTTL = 5 * time.Minute
     23 )
     24 
     25 type tokenProvider struct {
     26 	client   *Client
     27 	tokenTTL time.Duration
     28 
     29 	token    *apiToken
     30 	tokenMux sync.RWMutex
     31 
     32 	disabled uint32 // Atomic updated
     33 }
     34 
     35 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
     36 	return &tokenProvider{
     37 		client:   client,
     38 		tokenTTL: ttl,
     39 	}
     40 }
     41 
     42 // apiToken provides the API token used by all operation calls for th EC2
     43 // Instance metadata service.
     44 type apiToken struct {
     45 	token   string
     46 	expires time.Time
     47 }
     48 
     49 var timeNow = time.Now
     50 
     51 // Expired returns if the token is expired.
     52 func (t *apiToken) Expired() bool {
     53 	// Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
     54 	// time is always based on reported wall-clock time.
     55 	return timeNow().Round(0).After(t.expires)
     56 }
     57 
     58 func (t *tokenProvider) ID() string { return "APITokenProvider" }
     59 
     60 // HandleFinalize is the finalize stack middleware, that if the token provider is
     61 // enabled, will attempt to add the cached API token to the request. If the API
     62 // token is not cached, it will be retrieved in a separate API call, getToken.
     63 //
     64 // For retry attempts, handler must be added after attempt retryer.
     65 //
     66 // If request for getToken fails the token provider may be disabled from future
     67 // requests, depending on the response status code.
     68 func (t *tokenProvider) HandleFinalize(
     69 	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
     70 ) (
     71 	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
     72 ) {
     73 	if t.fallbackEnabled() && !t.enabled() {
     74 		// short-circuits to insecure data flow if token provider is disabled.
     75 		return next.HandleFinalize(ctx, input)
     76 	}
     77 
     78 	req, ok := input.Request.(*smithyhttp.Request)
     79 	if !ok {
     80 		return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
     81 	}
     82 
     83 	tok, err := t.getToken(ctx)
     84 	if err != nil {
     85 		// If the error allows the token to downgrade to insecure flow allow that.
     86 		var bypassErr *bypassTokenRetrievalError
     87 		if errors.As(err, &bypassErr) {
     88 			return next.HandleFinalize(ctx, input)
     89 		}
     90 
     91 		return out, metadata, fmt.Errorf("failed to get API token, %w", err)
     92 	}
     93 
     94 	req.Header.Set(tokenHeader, tok.token)
     95 
     96 	return next.HandleFinalize(ctx, input)
     97 }
     98 
     99 // HandleDeserialize is the deserialize stack middleware for determining if the
    100 // operation the token provider is decorating failed because of a 401
    101 // unauthorized status code. If the operation failed for that reason the token
    102 // provider needs to be re-enabled so that it can start adding the API token to
    103 // operation calls.
    104 func (t *tokenProvider) HandleDeserialize(
    105 	ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
    106 ) (
    107 	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
    108 ) {
    109 	out, metadata, err = next.HandleDeserialize(ctx, input)
    110 	if err == nil {
    111 		return out, metadata, err
    112 	}
    113 
    114 	resp, ok := out.RawResponse.(*smithyhttp.Response)
    115 	if !ok {
    116 		return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
    117 	}
    118 
    119 	if resp.StatusCode == http.StatusUnauthorized { // unauthorized
    120 		t.enable()
    121 		err = &retryableError{Err: err, isRetryable: true}
    122 	}
    123 
    124 	return out, metadata, err
    125 }
    126 
    127 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
    128 	if t.fallbackEnabled() && !t.enabled() {
    129 		return nil, &bypassTokenRetrievalError{
    130 			Err: fmt.Errorf("cannot get API token, provider disabled"),
    131 		}
    132 	}
    133 
    134 	t.tokenMux.RLock()
    135 	tok = t.token
    136 	t.tokenMux.RUnlock()
    137 
    138 	if tok != nil && !tok.Expired() {
    139 		return tok, nil
    140 	}
    141 
    142 	tok, err = t.updateToken(ctx)
    143 	if err != nil {
    144 		return nil, err
    145 	}
    146 
    147 	return tok, nil
    148 }
    149 
    150 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
    151 	t.tokenMux.Lock()
    152 	defer t.tokenMux.Unlock()
    153 
    154 	// Prevent multiple requests to update retrieving the token.
    155 	if t.token != nil && !t.token.Expired() {
    156 		tok := t.token
    157 		return tok, nil
    158 	}
    159 
    160 	result, err := t.client.getToken(ctx, &getTokenInput{
    161 		TokenTTL: t.tokenTTL,
    162 	})
    163 	if err != nil {
    164 		var statusErr interface{ HTTPStatusCode() int }
    165 		if errors.As(err, &statusErr) {
    166 			switch statusErr.HTTPStatusCode() {
    167 			// Disable future get token if failed because of 403, 404, or 405
    168 			case http.StatusForbidden,
    169 				http.StatusNotFound,
    170 				http.StatusMethodNotAllowed:
    171 
    172 				if t.fallbackEnabled() {
    173 					logger := middleware.GetLogger(ctx)
    174 					logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
    175 					t.disable()
    176 				}
    177 
    178 			// 400 errors are terminal, and need to be upstreamed
    179 			case http.StatusBadRequest:
    180 				return nil, err
    181 			}
    182 		}
    183 
    184 		// Disable if request send failed or timed out getting response
    185 		var re *smithyhttp.RequestSendError
    186 		var ce *smithy.CanceledError
    187 		if errors.As(err, &re) || errors.As(err, &ce) {
    188 			atomic.StoreUint32(&t.disabled, 1)
    189 		}
    190 
    191 		if !t.fallbackEnabled() {
    192 			// NOTE: getToken() is an implementation detail of some outer operation
    193 			// (e.g. GetMetadata). It has its own retries that have already been exhausted.
    194 			// Mark the underlying error as a terminal error.
    195 			err = &retryableError{Err: err, isRetryable: false}
    196 			return nil, err
    197 		}
    198 
    199 		// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
    200 		// and allow the request to proceed. Future requests _may_ re-attempt fetching a
    201 		// token if not disabled.
    202 		return nil, &bypassTokenRetrievalError{Err: err}
    203 	}
    204 
    205 	tok := &apiToken{
    206 		token:   result.Token,
    207 		expires: timeNow().Add(result.TokenTTL),
    208 	}
    209 	t.token = tok
    210 
    211 	return tok, nil
    212 }
    213 
    214 // enabled returns if the token provider is current enabled or not.
    215 func (t *tokenProvider) enabled() bool {
    216 	return atomic.LoadUint32(&t.disabled) == 0
    217 }
    218 
    219 // fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
    220 func (t *tokenProvider) fallbackEnabled() bool {
    221 	switch t.client.options.EnableFallback {
    222 	case aws.FalseTernary:
    223 		return false
    224 	default:
    225 		return true
    226 	}
    227 }
    228 
    229 // disable disables the token provider and it will no longer attempt to inject
    230 // the token, nor request updates.
    231 func (t *tokenProvider) disable() {
    232 	atomic.StoreUint32(&t.disabled, 1)
    233 }
    234 
    235 // enable enables the token provide to start refreshing tokens, and adding them
    236 // to the pending request.
    237 func (t *tokenProvider) enable() {
    238 	t.tokenMux.Lock()
    239 	t.token = nil
    240 	t.tokenMux.Unlock()
    241 	atomic.StoreUint32(&t.disabled, 0)
    242 }
    243 
    244 type bypassTokenRetrievalError struct {
    245 	Err error
    246 }
    247 
    248 func (e *bypassTokenRetrievalError) Error() string {
    249 	return fmt.Sprintf("bypass token retrieval, %v", e.Err)
    250 }
    251 
    252 func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }
    253 
    254 type retryableError struct {
    255 	Err         error
    256 	isRetryable bool
    257 }
    258 
    259 func (e *retryableError) RetryableError() bool { return e.isRetryable }
    260 
    261 func (e *retryableError) Error() string { return e.Err.Error() }