code.dwrz.net

Go monorepo.
Log | Files | Refs

token_provider.go (6310B)


      1 package imds
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 	"net/http"
      8 	"sync"
      9 	"sync/atomic"
     10 	"time"
     11 
     12 	smithy "github.com/aws/smithy-go"
     13 	"github.com/aws/smithy-go/middleware"
     14 	smithyhttp "github.com/aws/smithy-go/transport/http"
     15 )
     16 
const (
	// tokenHeader is the name of the HTTP request header that HandleFinalize
	// sets to the API token value.
	//
	// defaultTokenTTL is a five minute token lifetime. It is not referenced
	// in the visible part of this file; presumably it is the TTL applied when
	// the caller does not configure one (confirm against the client options).
	tokenHeader     = "x-aws-ec2-metadata-token"
	defaultTokenTTL = 5 * time.Minute
)
     22 
// tokenProvider caches an API token and adds it to outgoing requests. It
// implements the finalize and deserialize middleware hooks defined in this
// file (HandleFinalize and HandleDeserialize).
type tokenProvider struct {
	// client is used to request a new token; tokenTTL is the lifetime asked
	// for when doing so.
	client   *Client
	tokenTTL time.Duration

	// token is the most recently retrieved token, or nil when none is cached.
	// All reads and writes of token are guarded by tokenMux.
	token    *apiToken
	tokenMux sync.RWMutex

	disabled uint32 // Accessed with sync/atomic; 0 means enabled, 1 means disabled.
}
     32 
     33 func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
     34 	return &tokenProvider{
     35 		client:   client,
     36 		tokenTTL: ttl,
     37 	}
     38 }
     39 
     40 // apiToken provides the API token used by all operation calls for th EC2
     41 // Instance metadata service.
     42 type apiToken struct {
     43 	token   string
     44 	expires time.Time
     45 }
     46 
     47 var timeNow = time.Now
     48 
     49 // Expired returns if the token is expired.
     50 func (t *apiToken) Expired() bool {
     51 	// Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry
     52 	// time is always based on reported wall-clock time.
     53 	return timeNow().Round(0).After(t.expires)
     54 }
     55 
     56 func (t *tokenProvider) ID() string { return "APITokenProvider" }
     57 
     58 // HandleFinalize is the finalize stack middleware, that if the token provider is
     59 // enabled, will attempt to add the cached API token to the request. If the API
     60 // token is not cached, it will be retrieved in a separate API call, getToken.
     61 //
     62 // For retry attempts, handler must be added after attempt retryer.
     63 //
     64 // If request for getToken fails the token provider may be disabled from future
     65 // requests, depending on the response status code.
     66 func (t *tokenProvider) HandleFinalize(
     67 	ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
     68 ) (
     69 	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
     70 ) {
     71 	if !t.enabled() {
     72 		// short-circuits to insecure data flow if token provider is disabled.
     73 		return next.HandleFinalize(ctx, input)
     74 	}
     75 
     76 	req, ok := input.Request.(*smithyhttp.Request)
     77 	if !ok {
     78 		return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
     79 	}
     80 
     81 	tok, err := t.getToken(ctx)
     82 	if err != nil {
     83 		// If the error allows the token to downgrade to insecure flow allow that.
     84 		var bypassErr *bypassTokenRetrievalError
     85 		if errors.As(err, &bypassErr) {
     86 			return next.HandleFinalize(ctx, input)
     87 		}
     88 
     89 		return out, metadata, fmt.Errorf("failed to get API token, %w", err)
     90 	}
     91 
     92 	req.Header.Set(tokenHeader, tok.token)
     93 
     94 	return next.HandleFinalize(ctx, input)
     95 }
     96 
     97 // HandleDeserialize is the deserialize stack middleware for determining if the
     98 // operation the token provider is decorating failed because of a 401
     99 // unauthorized status code. If the operation failed for that reason the token
    100 // provider needs to be re-enabled so that it can start adding the API token to
    101 // operation calls.
    102 func (t *tokenProvider) HandleDeserialize(
    103 	ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
    104 ) (
    105 	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
    106 ) {
    107 	out, metadata, err = next.HandleDeserialize(ctx, input)
    108 	if err == nil {
    109 		return out, metadata, err
    110 	}
    111 
    112 	resp, ok := out.RawResponse.(*smithyhttp.Response)
    113 	if !ok {
    114 		return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
    115 	}
    116 
    117 	if resp.StatusCode == http.StatusUnauthorized { // unauthorized
    118 		err = &retryableError{Err: err}
    119 		t.enable()
    120 	}
    121 
    122 	return out, metadata, err
    123 }
    124 
    125 type retryableError struct {
    126 	Err error
    127 }
    128 
    129 func (*retryableError) RetryableError() bool { return true }
    130 
    131 func (e *retryableError) Error() string { return e.Err.Error() }
    132 
    133 func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
    134 	if !t.enabled() {
    135 		return nil, &bypassTokenRetrievalError{
    136 			Err: fmt.Errorf("cannot get API token, provider disabled"),
    137 		}
    138 	}
    139 
    140 	t.tokenMux.RLock()
    141 	tok = t.token
    142 	t.tokenMux.RUnlock()
    143 
    144 	if tok != nil && !tok.Expired() {
    145 		return tok, nil
    146 	}
    147 
    148 	tok, err = t.updateToken(ctx)
    149 	if err != nil {
    150 		return nil, fmt.Errorf("cannot get API token, %w", err)
    151 	}
    152 
    153 	return tok, nil
    154 }
    155 
    156 func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
    157 	t.tokenMux.Lock()
    158 	defer t.tokenMux.Unlock()
    159 
    160 	// Prevent multiple requests to update retrieving the token.
    161 	if t.token != nil && !t.token.Expired() {
    162 		tok := t.token
    163 		return tok, nil
    164 	}
    165 
    166 	result, err := t.client.getToken(ctx, &getTokenInput{
    167 		TokenTTL: t.tokenTTL,
    168 	})
    169 	if err != nil {
    170 		// change the disabled flag on token provider to true, when error is request timeout error.
    171 		var statusErr interface{ HTTPStatusCode() int }
    172 		if errors.As(err, &statusErr) {
    173 			switch statusErr.HTTPStatusCode() {
    174 
    175 			// Disable get token if failed because of 403, 404, or 405
    176 			case http.StatusForbidden,
    177 				http.StatusNotFound,
    178 				http.StatusMethodNotAllowed:
    179 
    180 				t.disable()
    181 
    182 			// 400 errors are terminal, and need to be upstreamed
    183 			case http.StatusBadRequest:
    184 				return nil, err
    185 			}
    186 		}
    187 
    188 		// Disable if request send failed or timed out getting response
    189 		var re *smithyhttp.RequestSendError
    190 		var ce *smithy.CanceledError
    191 		if errors.As(err, &re) || errors.As(err, &ce) {
    192 			atomic.StoreUint32(&t.disabled, 1)
    193 		}
    194 
    195 		// Token couldn't be retrieved, but bypass this, and allow the
    196 		// request to continue.
    197 		return nil, &bypassTokenRetrievalError{Err: err}
    198 	}
    199 
    200 	tok := &apiToken{
    201 		token:   result.Token,
    202 		expires: timeNow().Add(result.TokenTTL),
    203 	}
    204 	t.token = tok
    205 
    206 	return tok, nil
    207 }
    208 
    209 type bypassTokenRetrievalError struct {
    210 	Err error
    211 }
    212 
    213 func (e *bypassTokenRetrievalError) Error() string {
    214 	return fmt.Sprintf("bypass token retrieval, %v", e.Err)
    215 }
    216 
    217 func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }
    218 
    219 // enabled returns if the token provider is current enabled or not.
    220 func (t *tokenProvider) enabled() bool {
    221 	return atomic.LoadUint32(&t.disabled) == 0
    222 }
    223 
    224 // disable disables the token provider and it will no longer attempt to inject
    225 // the token, nor request updates.
    226 func (t *tokenProvider) disable() {
    227 	atomic.StoreUint32(&t.disabled, 1)
    228 }
    229 
    230 // enable enables the token provide to start refreshing tokens, and adding them
    231 // to the pending request.
    232 func (t *tokenProvider) enable() {
    233 	t.tokenMux.Lock()
    234 	t.token = nil
    235 	t.tokenMux.Unlock()
    236 	atomic.StoreUint32(&t.disabled, 0)
    237 }