src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

request_middleware.go (8600B)


      1 package imds
      2 
      3 import (
      4 	"bytes"
      5 	"context"
      6 	"fmt"
      7 	"io/ioutil"
      8 	"net/url"
      9 	"path"
     10 	"time"
     11 
     12 	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
     13 	"github.com/aws/aws-sdk-go-v2/aws/retry"
     14 	"github.com/aws/smithy-go/middleware"
     15 	smithyhttp "github.com/aws/smithy-go/transport/http"
     16 )
     17 
     18 func addAPIRequestMiddleware(stack *middleware.Stack,
     19 	options Options,
     20 	operation string,
     21 	getPath func(interface{}) (string, error),
     22 	getOutput func(*smithyhttp.Response) (interface{}, error),
     23 ) (err error) {
     24 	err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput)
     25 	if err != nil {
     26 		return err
     27 	}
     28 
     29 	// Token Serializer build and state management.
     30 	if !options.disableAPIToken {
     31 		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
     32 		if err != nil {
     33 			return err
     34 		}
     35 
     36 		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
     37 		if err != nil {
     38 			return err
     39 		}
     40 	}
     41 
     42 	return nil
     43 }
     44 
     45 func addRequestMiddleware(stack *middleware.Stack,
     46 	options Options,
     47 	method string,
     48 	operation string,
     49 	getPath func(interface{}) (string, error),
     50 	getOutput func(*smithyhttp.Response) (interface{}, error),
     51 ) (err error) {
     52 	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
     53 	if err != nil {
     54 		return err
     55 	}
     56 
     57 	// Operation timeout
     58 	err = stack.Initialize.Add(&operationTimeout{
     59 		DefaultTimeout: defaultOperationTimeout,
     60 	}, middleware.Before)
     61 	if err != nil {
     62 		return err
     63 	}
     64 
     65 	// Operation Serializer
     66 	err = stack.Serialize.Add(&serializeRequest{
     67 		GetPath: getPath,
     68 		Method:  method,
     69 	}, middleware.After)
     70 	if err != nil {
     71 		return err
     72 	}
     73 
     74 	// Operation endpoint resolver
     75 	err = stack.Serialize.Insert(&resolveEndpoint{
     76 		Endpoint:     options.Endpoint,
     77 		EndpointMode: options.EndpointMode,
     78 	}, "OperationSerializer", middleware.Before)
     79 	if err != nil {
     80 		return err
     81 	}
     82 
     83 	// Operation Deserializer
     84 	err = stack.Deserialize.Add(&deserializeResponse{
     85 		GetOutput: getOutput,
     86 	}, middleware.After)
     87 	if err != nil {
     88 		return err
     89 	}
     90 
     91 	err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{
     92 		LogRequest:          options.ClientLogMode.IsRequest(),
     93 		LogRequestWithBody:  options.ClientLogMode.IsRequestWithBody(),
     94 		LogResponse:         options.ClientLogMode.IsResponse(),
     95 		LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(),
     96 	}, middleware.After)
     97 	if err != nil {
     98 		return err
     99 	}
    100 
    101 	err = addSetLoggerMiddleware(stack, options)
    102 	if err != nil {
    103 		return err
    104 	}
    105 
    106 	if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil {
    107 		return fmt.Errorf("add protocol finalizers: %w", err)
    108 	}
    109 
    110 	// Retry support
    111 	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
    112 		Retryer:          options.Retryer,
    113 		LogRetryAttempts: options.ClientLogMode.IsRetries(),
    114 	})
    115 }
    116 
    117 func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error {
    118 	return middleware.AddSetLoggerMiddleware(stack, o.Logger)
    119 }
    120 
    121 type serializeRequest struct {
    122 	GetPath func(interface{}) (string, error)
    123 	Method  string
    124 }
    125 
    126 func (*serializeRequest) ID() string {
    127 	return "OperationSerializer"
    128 }
    129 
    130 func (m *serializeRequest) HandleSerialize(
    131 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    132 ) (
    133 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    134 ) {
    135 	request, ok := in.Request.(*smithyhttp.Request)
    136 	if !ok {
    137 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    138 	}
    139 
    140 	reqPath, err := m.GetPath(in.Parameters)
    141 	if err != nil {
    142 		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
    143 	}
    144 
    145 	request.Request.URL.Path = reqPath
    146 	request.Request.Method = m.Method
    147 
    148 	return next.HandleSerialize(ctx, in)
    149 }
    150 
    151 type deserializeResponse struct {
    152 	GetOutput func(*smithyhttp.Response) (interface{}, error)
    153 }
    154 
    155 func (*deserializeResponse) ID() string {
    156 	return "OperationDeserializer"
    157 }
    158 
    159 func (m *deserializeResponse) HandleDeserialize(
    160 	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
    161 ) (
    162 	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
    163 ) {
    164 	out, metadata, err = next.HandleDeserialize(ctx, in)
    165 	if err != nil {
    166 		return out, metadata, err
    167 	}
    168 
    169 	resp, ok := out.RawResponse.(*smithyhttp.Response)
    170 	if !ok {
    171 		return out, metadata, fmt.Errorf(
    172 			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
    173 	}
    174 	defer resp.Body.Close()
    175 
    176 	// read the full body so that any operation timeouts cleanup will not race
    177 	// the body being read.
    178 	body, err := ioutil.ReadAll(resp.Body)
    179 	if err != nil {
    180 		return out, metadata, fmt.Errorf("read response body failed, %w", err)
    181 	}
    182 	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
    183 
    184 	// Anything that's not 200 |< 300 is error
    185 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
    186 		return out, metadata, &smithyhttp.ResponseError{
    187 			Response: resp,
    188 			Err:      fmt.Errorf("request to EC2 IMDS failed"),
    189 		}
    190 	}
    191 
    192 	result, err := m.GetOutput(resp)
    193 	if err != nil {
    194 		return out, metadata, fmt.Errorf(
    195 			"unable to get deserialized result for response, %w", err,
    196 		)
    197 	}
    198 	out.Result = result
    199 
    200 	return out, metadata, err
    201 }
    202 
    203 type resolveEndpoint struct {
    204 	Endpoint     string
    205 	EndpointMode EndpointModeState
    206 }
    207 
    208 func (*resolveEndpoint) ID() string {
    209 	return "ResolveEndpoint"
    210 }
    211 
    212 func (m *resolveEndpoint) HandleSerialize(
    213 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    214 ) (
    215 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    216 ) {
    217 
    218 	req, ok := in.Request.(*smithyhttp.Request)
    219 	if !ok {
    220 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    221 	}
    222 
    223 	var endpoint string
    224 	if len(m.Endpoint) > 0 {
    225 		endpoint = m.Endpoint
    226 	} else {
    227 		switch m.EndpointMode {
    228 		case EndpointModeStateIPv6:
    229 			endpoint = defaultIPv6Endpoint
    230 		case EndpointModeStateIPv4:
    231 			fallthrough
    232 		case EndpointModeStateUnset:
    233 			endpoint = defaultIPv4Endpoint
    234 		default:
    235 			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
    236 		}
    237 	}
    238 
    239 	req.URL, err = url.Parse(endpoint)
    240 	if err != nil {
    241 		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
    242 	}
    243 
    244 	return next.HandleSerialize(ctx, in)
    245 }
    246 
    247 const (
    248 	defaultOperationTimeout = 5 * time.Second
    249 )
    250 
    251 // operationTimeout adds a timeout on the middleware stack if the Context the
    252 // stack was called with does not have a deadline. The next middleware must
    253 // complete before the timeout, or the context will be canceled.
    254 //
    255 // If DefaultTimeout is zero, no default timeout will be used if the Context
    256 // does not have a timeout.
    257 //
    258 // The next middleware must also ensure that any resources that are also
    259 // canceled by the stack's context are completely consumed before returning.
    260 // Otherwise the timeout cleanup will race the resource being consumed
    261 // upstream.
    262 type operationTimeout struct {
    263 	DefaultTimeout time.Duration
    264 }
    265 
    266 func (*operationTimeout) ID() string { return "OperationTimeout" }
    267 
    268 func (m *operationTimeout) HandleInitialize(
    269 	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
    270 ) (
    271 	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
    272 ) {
    273 	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
    274 		var cancelFn func()
    275 		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
    276 		defer cancelFn()
    277 	}
    278 
    279 	return next.HandleInitialize(ctx, input)
    280 }
    281 
    282 // appendURIPath joins a URI path component to the existing path with `/`
    283 // separators between the path components. If the path being added ends with a
    284 // trailing `/` that slash will be maintained.
    285 func appendURIPath(base, add string) string {
    286 	reqPath := path.Join(base, add)
    287 	if len(add) != 0 && add[len(add)-1] == '/' {
    288 		reqPath += "/"
    289 	}
    290 	return reqPath
    291 }
    292 
    293 func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error {
    294 	if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil {
    295 		return fmt.Errorf("add ResolveAuthScheme: %w", err)
    296 	}
    297 	if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil {
    298 		return fmt.Errorf("add GetIdentity: %w", err)
    299 	}
    300 	if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil {
    301 		return fmt.Errorf("add ResolveEndpointV2: %w", err)
    302 	}
    303 	if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil {
    304 		return fmt.Errorf("add Signing: %w", err)
    305 	}
    306 	return nil
    307 }