src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

request_middleware.go (8734B)


      1 package imds
      2 
      3 import (
      4 	"bytes"
      5 	"context"
      6 	"fmt"
      7 	"io/ioutil"
      8 	"net/url"
      9 	"path"
     10 	"time"
     11 
     12 	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
     13 	"github.com/aws/aws-sdk-go-v2/aws/retry"
     14 	"github.com/aws/smithy-go/middleware"
     15 	smithyhttp "github.com/aws/smithy-go/transport/http"
     16 )
     17 
     18 func addAPIRequestMiddleware(stack *middleware.Stack,
     19 	options Options,
     20 	operation string,
     21 	getPath func(interface{}) (string, error),
     22 	getOutput func(*smithyhttp.Response) (interface{}, error),
     23 ) (err error) {
     24 	err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
     25 	if err != nil {
     26 		return err
     27 	}
     28 
     29 	// Token Serializer build and state management.
     30 	if !options.disableAPIToken {
     31 		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
     32 		if err != nil {
     33 			return err
     34 		}
     35 
     36 		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
     37 		if err != nil {
     38 			return err
     39 		}
     40 	}
     41 
     42 	return nil
     43 }
     44 
     45 func addRequestMiddleware(stack *middleware.Stack,
     46 	options Options,
     47 	method string,
     48 	operation string,
     49 	getPath func(interface{}) (string, error),
     50 	getOutput func(*smithyhttp.Response) (interface{}, error),
     51 ) (err error) {
     52 	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
     53 	if err != nil {
     54 		return err
     55 	}
     56 
     57 	// Operation timeout
     58 	err = stack.Initialize.Add(&operationTimeout{
     59 		Disabled:       options.DisableDefaultTimeout,
     60 		DefaultTimeout: defaultOperationTimeout,
     61 	}, middleware.Before)
     62 	if err != nil {
     63 		return err
     64 	}
     65 
     66 	// Operation Serializer
     67 	err = stack.Serialize.Add(&serializeRequest{
     68 		GetPath: getPath,
     69 		Method:  method,
     70 	}, middleware.After)
     71 	if err != nil {
     72 		return err
     73 	}
     74 
     75 	// Operation endpoint resolver
     76 	err = stack.Serialize.Insert(&resolveEndpoint{
     77 		Endpoint:     options.Endpoint,
     78 		EndpointMode: options.EndpointMode,
     79 	}, "OperationSerializer", middleware.Before)
     80 	if err != nil {
     81 		return err
     82 	}
     83 
     84 	// Operation Deserializer
     85 	err = stack.Deserialize.Add(&deserializeResponse{
     86 		GetOutput: getOutput,
     87 	}, middleware.After)
     88 	if err != nil {
     89 		return err
     90 	}
     91 
     92 	err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
     93 		LogRequest:          options.ClientLogMode.IsRequest(),
     94 		LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
     95 		LogResponse:         options.ClientLogMode.IsResponse(),
     96 		LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
     97 	}, middleware.After)
     98 	if err != nil {
     99 		return err
    100 	}
    101 
    102 	err = addSetLoggerMiddleware(stack, options)
    103 	if err != nil {
    104 		return err
    105 	}
    106 
    107 	if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
    108 		return fmt.Errorf("add protocol finalizers: %w", err)
    109 	}
    110 
    111 	// Retry support
    112 	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
    113 		Retryer:          options.Retryer,
    114 		LogRetryAttempts: options.ClientLogMode.IsRetries(),
    115 	})
    116 }
    117 
    118 func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
    119 	return middleware.AddSetLoggerMiddleware(stack, o.Logger)
    120 }
    121 
    122 type serializeRequest struct {
    123 	GetPath func(interface{}) (string, error)
    124 	Method  string
    125 }
    126 
    127 func (*serializeRequest) ID() string {
    128 	return "OperationSerializer"
    129 }
    130 
    131 func (m *serializeRequest) HandleSerialize(
    132 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    133 ) (
    134 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    135 ) {
    136 	request, ok := in.Request.(*smithyhttp.Request)
    137 	if !ok {
    138 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    139 	}
    140 
    141 	reqPath, err := m.GetPath(in.Parameters)
    142 	if err != nil {
    143 		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
    144 	}
    145 
    146 	request.Request.URL.Path = reqPath
    147 	request.Request.Method = m.Method
    148 
    149 	return next.HandleSerialize(ctx, in)
    150 }
    151 
    152 type deserializeResponse struct {
    153 	GetOutput func(*smithyhttp.Response) (interface{}, error)
    154 }
    155 
    156 func (*deserializeResponse) ID() string {
    157 	return "OperationDeserializer"
    158 }
    159 
    160 func (m *deserializeResponse) HandleDeserialize(
    161 	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
    162 ) (
    163 	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
    164 ) {
    165 	out, metadata, err = next.HandleDeserialize(ctx, in)
    166 	if err != nil {
    167 		return out, metadata, err
    168 	}
    169 
    170 	resp, ok := out.RawResponse.(*smithyhttp.Response)
    171 	if !ok {
    172 		return out, metadata, fmt.Errorf(
    173 			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
    174 	}
    175 	defer resp.Body.Close()
    176 
    177 	// read the full body so that any operation timeouts cleanup will not race
    178 	// the body being read.
    179 	body, err := ioutil.ReadAll(resp.Body)
    180 	if err != nil {
    181 		return out, metadata, fmt.Errorf("read response body failed, %w", err)
    182 	}
    183 	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
    184 
    185 	// Anything that's not 200 |< 300 is error
    186 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
    187 		return out, metadata, &smithyhttp.ResponseError{
    188 			Response: resp,
    189 			Err:      fmt.Errorf("request to EC2 IMDS failed"),
    190 		}
    191 	}
    192 
    193 	result, err := m.GetOutput(resp)
    194 	if err != nil {
    195 		return out, metadata, fmt.Errorf(
    196 			"unable to get deserialized result for response, %w", err,
    197 		)
    198 	}
    199 	out.Result = result
    200 
    201 	return out, metadata, err
    202 }
    203 
    204 type resolveEndpoint struct {
    205 	Endpoint     string
    206 	EndpointMode EndpointModeState
    207 }
    208 
    209 func (*resolveEndpoint) ID() string {
    210 	return "ResolveEndpoint"
    211 }
    212 
    213 func (m *resolveEndpoint) HandleSerialize(
    214 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    215 ) (
    216 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    217 ) {
    218 
    219 	req, ok := in.Request.(*smithyhttp.Request)
    220 	if !ok {
    221 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    222 	}
    223 
    224 	var endpoint string
    225 	if len(m.Endpoint) > 0 {
    226 		endpoint = m.Endpoint
    227 	} else {
    228 		switch m.EndpointMode {
    229 		case EndpointModeStateIPv6:
    230 			endpoint = defaultIPv6Endpoint
    231 		case EndpointModeStateIPv4:
    232 			fallthrough
    233 		case EndpointModeStateUnset:
    234 			endpoint = defaultIPv4Endpoint
    235 		default:
    236 			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
    237 		}
    238 	}
    239 
    240 	req.URL, err = url.Parse(endpoint)
    241 	if err != nil {
    242 		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
    243 	}
    244 
    245 	return next.HandleSerialize(ctx, in)
    246 }
    247 
    248 const (
    249 	defaultOperationTimeout = 5 * time.Second
    250 )
    251 
    252 // operationTimeout adds a timeout on the middleware stack if the Context the
    253 // stack was called with does not have a deadline. The next middleware must
    254 // complete before the timeout, or the context will be canceled.
    255 //
    256 // If DefaultTimeout is zero, no default timeout will be used if the Context
    257 // does not have a timeout.
    258 //
    259 // The next middleware must also ensure that any resources that are also
    260 // canceled by the stack's context are completely consumed before returning.
    261 // Otherwise the timeout cleanup will race the resource being consumed
    262 // upstream.
    263 type operationTimeout struct {
    264 	Disabled       bool
    265 	DefaultTimeout time.Duration
    266 }
    267 
    268 func (*operationTimeout) ID() string { return "OperationTimeout" }
    269 
    270 func (m *operationTimeout) HandleInitialize(
    271 	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
    272 ) (
    273 	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
    274 ) {
    275 	if m.Disabled {
    276 		return next.HandleInitialize(ctx, input)
    277 	}
    278 
    279 	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
    280 		var cancelFn func()
    281 		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
    282 		defer cancelFn()
    283 	}
    284 
    285 	return next.HandleInitialize(ctx, input)
    286 }
    287 
    288 // appendURIPath joins a URI path component to the existing path with `/`
    289 // separators between the path components. If the path being added ends with a
    290 // trailing `/` that slash will be maintained.
    291 func appendURIPath(base, add string) string {
    292 	reqPath := path.Join(base, add)
    293 	if len(add) != 0 && add[len(add)-1] == '/' {
    294 		reqPath += "/"
    295 	}
    296 	return reqPath
    297 }
    298 
    299 func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
    300 	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
    301 		return fmt.Errorf("add ResolveAuthScheme: %w", err)
    302 	}
    303 	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
    304 		return fmt.Errorf("add GetIdentity: %w", err)
    305 	}
    306 	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
    307 		return fmt.Errorf("add ResolveEndpointV2: %w", err)
    308 	}
    309 	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
    310 		return fmt.Errorf("add Signing: %w", err)
    311 	}
    312 	return nil
    313 }