code.dwrz.net

Go monorepo.
Log | Files | Refs

request_middleware.go (6990B)


      1 package imds
      2 
      3 import (
      4 	"bytes"
      5 	"context"
      6 	"fmt"
      7 	"io/ioutil"
      8 	"net/url"
      9 	"path"
     10 	"time"
     11 
     12 	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
     13 	"github.com/aws/aws-sdk-go-v2/aws/retry"
     14 	"github.com/aws/smithy-go/middleware"
     15 	smithyhttp "github.com/aws/smithy-go/transport/http"
     16 )
     17 
     18 func addAPIRequestMiddleware(stack *middleware.Stack,
     19 	options Options,
     20 	getPath func(interface{}) (string, error),
     21 	getOutput func(*smithyhttp.Response) (interface{}, error),
     22 ) (err error) {
     23 	err = addRequestMiddleware(stack, options, "GET", getPath, getOutput)
     24 	if err != nil {
     25 		return err
     26 	}
     27 
     28 	// Token Serializer build and state management.
     29 	if !options.disableAPIToken {
     30 		err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After)
     31 		if err != nil {
     32 			return err
     33 		}
     34 
     35 		err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before)
     36 		if err != nil {
     37 			return err
     38 		}
     39 	}
     40 
     41 	return nil
     42 }
     43 
     44 func addRequestMiddleware(stack *middleware.Stack,
     45 	options Options,
     46 	method string,
     47 	getPath func(interface{}) (string, error),
     48 	getOutput func(*smithyhttp.Response) (interface{}, error),
     49 ) (err error) {
     50 	err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack)
     51 	if err != nil {
     52 		return err
     53 	}
     54 
     55 	// Operation timeout
     56 	err = stack.Initialize.Add(&operationTimeout{
     57 		DefaultTimeout: defaultOperationTimeout,
     58 	}, middleware.Before)
     59 	if err != nil {
     60 		return err
     61 	}
     62 
     63 	// Operation Serializer
     64 	err = stack.Serialize.Add(&serializeRequest{
     65 		GetPath: getPath,
     66 		Method:  method,
     67 	}, middleware.After)
     68 	if err != nil {
     69 		return err
     70 	}
     71 
     72 	// Operation endpoint resolver
     73 	err = stack.Serialize.Insert(&resolveEndpoint{
     74 		Endpoint:     options.Endpoint,
     75 		EndpointMode: options.EndpointMode,
     76 	}, "OperationSerializer", middleware.Before)
     77 	if err != nil {
     78 		return err
     79 	}
     80 
     81 	// Operation Deserializer
     82 	err = stack.Deserialize.Add(&deserializeResponse{
     83 		GetOutput: getOutput,
     84 	}, middleware.After)
     85 	if err != nil {
     86 		return err
     87 	}
     88 
     89 	// Retry support
     90 	return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{
     91 		Retryer:          options.Retryer,
     92 		LogRetryAttempts: options.ClientLogMode.IsRetries(),
     93 	})
     94 }
     95 
     96 type serializeRequest struct {
     97 	GetPath func(interface{}) (string, error)
     98 	Method  string
     99 }
    100 
    101 func (*serializeRequest) ID() string {
    102 	return "OperationSerializer"
    103 }
    104 
    105 func (m *serializeRequest) HandleSerialize(
    106 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    107 ) (
    108 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    109 ) {
    110 	request, ok := in.Request.(*smithyhttp.Request)
    111 	if !ok {
    112 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    113 	}
    114 
    115 	reqPath, err := m.GetPath(in.Parameters)
    116 	if err != nil {
    117 		return out, metadata, fmt.Errorf("unable to get request URL path, %w", err)
    118 	}
    119 
    120 	request.Request.URL.Path = reqPath
    121 	request.Request.Method = m.Method
    122 
    123 	return next.HandleSerialize(ctx, in)
    124 }
    125 
    126 type deserializeResponse struct {
    127 	GetOutput func(*smithyhttp.Response) (interface{}, error)
    128 }
    129 
    130 func (*deserializeResponse) ID() string {
    131 	return "OperationDeserializer"
    132 }
    133 
    134 func (m *deserializeResponse) HandleDeserialize(
    135 	ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
    136 ) (
    137 	out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
    138 ) {
    139 	out, metadata, err = next.HandleDeserialize(ctx, in)
    140 	if err != nil {
    141 		return out, metadata, err
    142 	}
    143 
    144 	resp, ok := out.RawResponse.(*smithyhttp.Response)
    145 	if !ok {
    146 		return out, metadata, fmt.Errorf(
    147 			"unexpected transport response type, %T, want %T", out.RawResponse, resp)
    148 	}
    149 	defer resp.Body.Close()
    150 
    151 	// read the full body so that any operation timeouts cleanup will not race
    152 	// the body being read.
    153 	body, err := ioutil.ReadAll(resp.Body)
    154 	if err != nil {
    155 		return out, metadata, fmt.Errorf("read response body failed, %w", err)
    156 	}
    157 	resp.Body = ioutil.NopCloser(bytes.NewReader(body))
    158 
    159 	// Anything that's not 200 |< 300 is error
    160 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
    161 		return out, metadata, &smithyhttp.ResponseError{
    162 			Response: resp,
    163 			Err:      fmt.Errorf("request to EC2 IMDS failed"),
    164 		}
    165 	}
    166 
    167 	result, err := m.GetOutput(resp)
    168 	if err != nil {
    169 		return out, metadata, fmt.Errorf(
    170 			"unable to get deserialized result for response, %w", err,
    171 		)
    172 	}
    173 	out.Result = result
    174 
    175 	return out, metadata, err
    176 }
    177 
    178 type resolveEndpoint struct {
    179 	Endpoint     string
    180 	EndpointMode EndpointModeState
    181 }
    182 
    183 func (*resolveEndpoint) ID() string {
    184 	return "ResolveEndpoint"
    185 }
    186 
    187 func (m *resolveEndpoint) HandleSerialize(
    188 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
    189 ) (
    190 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
    191 ) {
    192 
    193 	req, ok := in.Request.(*smithyhttp.Request)
    194 	if !ok {
    195 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
    196 	}
    197 
    198 	var endpoint string
    199 	if len(m.Endpoint) > 0 {
    200 		endpoint = m.Endpoint
    201 	} else {
    202 		switch m.EndpointMode {
    203 		case EndpointModeStateIPv6:
    204 			endpoint = defaultIPv6Endpoint
    205 		case EndpointModeStateIPv4:
    206 			fallthrough
    207 		case EndpointModeStateUnset:
    208 			endpoint = defaultIPv4Endpoint
    209 		default:
    210 			return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode")
    211 		}
    212 	}
    213 
    214 	req.URL, err = url.Parse(endpoint)
    215 	if err != nil {
    216 		return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err)
    217 	}
    218 
    219 	return next.HandleSerialize(ctx, in)
    220 }
    221 
    222 const (
    223 	defaultOperationTimeout = 5 * time.Second
    224 )
    225 
    226 // operationTimeout adds a timeout on the middleware stack if the Context the
    227 // stack was called with does not have a deadline. The next middleware must
    228 // complete before the timeout, or the context will be canceled.
    229 //
    230 // If DefaultTimeout is zero, no default timeout will be used if the Context
    231 // does not have a timeout.
    232 //
    233 // The next middleware must also ensure that any resources that are also
    234 // canceled by the stack's context are completely consumed before returning.
    235 // Otherwise the timeout cleanup will race the resource being consumed
    236 // upstream.
    237 type operationTimeout struct {
    238 	DefaultTimeout time.Duration
    239 }
    240 
    241 func (*operationTimeout) ID() string { return "OperationTimeout" }
    242 
    243 func (m *operationTimeout) HandleInitialize(
    244 	ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler,
    245 ) (
    246 	output middleware.InitializeOutput, metadata middleware.Metadata, err error,
    247 ) {
    248 	if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
    249 		var cancelFn func()
    250 		ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
    251 		defer cancelFn()
    252 	}
    253 
    254 	return next.HandleInitialize(ctx, input)
    255 }
    256 
    257 // appendURIPath joins a URI path component to the existing path with `/`
    258 // separators between the path components. If the path being added ends with a
    259 // trailing `/` that slash will be maintained.
    260 func appendURIPath(base, add string) string {
    261 	reqPath := path.Join(base, add)
    262 	if len(add) != 0 && add[len(add)-1] == '/' {
    263 		reqPath += "/"
    264 	}
    265 	return reqPath
    266 }