request_middleware.go (8734B)
1 package imds 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io/ioutil" 8 "net/url" 9 "path" 10 "time" 11 12 awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" 13 "github.com/aws/aws-sdk-go-v2/aws/retry" 14 "github.com/aws/smithy-go/middleware" 15 smithyhttp "github.com/aws/smithy-go/transport/http" 16 ) 17 18 func addAPIRequestMiddleware(stack *middleware.Stack, 19 options Options, 20 operation string, 21 getPath func(interface{}) (string, error), 22 getOutput func(*smithyhttp.Response) (interface{}, error), 23 ) (err error) { 24 err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput) 25 if err != nil { 26 return err 27 } 28 29 // Token Serializer build and state management. 30 if !options.disableAPIToken { 31 err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After) 32 if err != nil { 33 return err 34 } 35 36 err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before) 37 if err != nil { 38 return err 39 } 40 } 41 42 return nil 43 } 44 45 func addRequestMiddleware(stack *middleware.Stack, 46 options Options, 47 method string, 48 operation string, 49 getPath func(interface{}) (string, error), 50 getOutput func(*smithyhttp.Response) (interface{}, error), 51 ) (err error) { 52 err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack) 53 if err != nil { 54 return err 55 } 56 57 // Operation timeout 58 err = stack.Initialize.Add(&operationTimeout{ 59 Disabled: options.DisableDefaultTimeout, 60 DefaultTimeout: defaultOperationTimeout, 61 }, middleware.Before) 62 if err != nil { 63 return err 64 } 65 66 // Operation Serializer 67 err = stack.Serialize.Add(&serializeRequest{ 68 GetPath: getPath, 69 Method: method, 70 }, middleware.After) 71 if err != nil { 72 return err 73 } 74 75 // Operation endpoint resolver 76 err = stack.Serialize.Insert(&resolveEndpoint{ 77 Endpoint: options.Endpoint, 78 EndpointMode: options.EndpointMode, 79 }, "OperationSerializer", middleware.Before) 80 if err != nil { 81 return err 82 } 83 84 // Operation Deserializer 85 err = stack.Deserialize.Add(&deserializeResponse{ 86 GetOutput: getOutput, 87 }, middleware.After) 88 if err != nil { 89 return err 90 } 91 92 err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{ 93 LogRequest: options.ClientLogMode.IsRequest(), 94 LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(), 95 LogResponse: options.ClientLogMode.IsResponse(), 96 LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(), 97 }, middleware.After) 98 if err != nil { 99 return err 100 } 101 102 err = addSetLoggerMiddleware(stack, options) 103 if err != nil { 104 return err 105 } 106 107 if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil { 108 return fmt.Errorf("add protocol finalizers: %w", err) 109 } 110 111 // Retry support 112 return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{ 113 Retryer: options.Retryer, 114 LogRetryAttempts: options.ClientLogMode.IsRetries(), 115 }) 116 } 117 118 func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error { 119 return middleware.AddSetLoggerMiddleware(stack, o.Logger) 120 } 121 122 type serializeRequest struct { 123 GetPath func(interface{}) (string, error) 124 Method string 125 } 126 127 func (*serializeRequest) ID() string { 128 return "OperationSerializer" 129 } 130 131 func (m *serializeRequest) HandleSerialize( 132 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 133 ) ( 134 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 135 ) { 136 request, ok := in.Request.(*smithyhttp.Request) 137 if !ok { 138 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 139 } 140 141 reqPath, err := m.GetPath(in.Parameters) 142 if err != nil { 143 return out, metadata, fmt.Errorf("unable to get request URL path, %w", err) 144 } 145 146 request.Request.URL.Path = reqPath 147 request.Request.Method = m.Method 148 149 return next.HandleSerialize(ctx, in) 150 } 151 152 type deserializeResponse struct { 153 GetOutput func(*smithyhttp.Response) (interface{}, error) 154 } 155 156 func (*deserializeResponse) ID() string { 157 return "OperationDeserializer" 158 } 159 160 func (m *deserializeResponse) HandleDeserialize( 161 ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, 162 ) ( 163 out middleware.DeserializeOutput, metadata middleware.Metadata, err error, 164 ) { 165 out, metadata, err = next.HandleDeserialize(ctx, in) 166 if err != nil { 167 return out, metadata, err 168 } 169 170 resp, ok := out.RawResponse.(*smithyhttp.Response) 171 if !ok { 172 return out, metadata, fmt.Errorf( 173 "unexpected transport response type, %T, want %T", out.RawResponse, resp) 174 } 175 defer resp.Body.Close() 176 177 // read the full body so that any operation timeouts cleanup will not race 178 // the body being read. 179 body, err := ioutil.ReadAll(resp.Body) 180 if err != nil { 181 return out, metadata, fmt.Errorf("read response body failed, %w", err) 182 } 183 resp.Body = ioutil.NopCloser(bytes.NewReader(body)) 184 185 // Anything that's not 200 |< 300 is error 186 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 187 return out, metadata, &smithyhttp.ResponseError{ 188 Response: resp, 189 Err: fmt.Errorf("request to EC2 IMDS failed"), 190 } 191 } 192 193 result, err := m.GetOutput(resp) 194 if err != nil { 195 return out, metadata, fmt.Errorf( 196 "unable to get deserialized result for response, %w", err, 197 ) 198 } 199 out.Result = result 200 201 return out, metadata, err 202 } 203 204 type resolveEndpoint struct { 205 Endpoint string 206 EndpointMode EndpointModeState 207 } 208 209 func (*resolveEndpoint) ID() string { 210 return "ResolveEndpoint" 211 } 212 213 func (m *resolveEndpoint) HandleSerialize( 214 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 215 ) ( 216 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 217 ) { 218 219 req, ok := in.Request.(*smithyhttp.Request) 220 if !ok { 221 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 222 } 223 224 var endpoint string 225 if len(m.Endpoint) > 0 { 226 endpoint = m.Endpoint 227 } else { 228 switch m.EndpointMode { 229 case EndpointModeStateIPv6: 230 endpoint = defaultIPv6Endpoint 231 case EndpointModeStateIPv4: 232 fallthrough 233 case EndpointModeStateUnset: 234 endpoint = defaultIPv4Endpoint 235 default: 236 return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode") 237 } 238 } 239 240 req.URL, err = url.Parse(endpoint) 241 if err != nil { 242 return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err) 243 } 244 245 return next.HandleSerialize(ctx, in) 246 } 247 248 const ( 249 defaultOperationTimeout = 5 * time.Second 250 ) 251 252 // operationTimeout adds a timeout on the middleware stack if the Context the 253 // stack was called with does not have a deadline. The next middleware must 254 // complete before the timeout, or the context will be canceled. 255 // 256 // If DefaultTimeout is zero, no default timeout will be used if the Context 257 // does not have a timeout. 258 // 259 // The next middleware must also ensure that any resources that are also 260 // canceled by the stack's context are completely consumed before returning. 261 // Otherwise the timeout cleanup will race the resource being consumed 262 // upstream. 263 type operationTimeout struct { 264 Disabled bool 265 DefaultTimeout time.Duration 266 } 267 268 func (*operationTimeout) ID() string { return "OperationTimeout" } 269 270 func (m *operationTimeout) HandleInitialize( 271 ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler, 272 ) ( 273 output middleware.InitializeOutput, metadata middleware.Metadata, err error, 274 ) { 275 if m.Disabled { 276 return next.HandleInitialize(ctx, input) 277 } 278 279 if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 { 280 var cancelFn func() 281 ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout) 282 defer cancelFn() 283 } 284 285 return next.HandleInitialize(ctx, input) 286 } 287 288 // appendURIPath joins a URI path component to the existing path with `/` 289 // separators between the path components. If the path being added ends with a 290 // trailing `/` that slash will be maintained. 291 func appendURIPath(base, add string) string { 292 reqPath := path.Join(base, add) 293 if len(add) != 0 && add[len(add)-1] == '/' { 294 reqPath += "/" 295 } 296 return reqPath 297 } 298 299 func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error { 300 if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil { 301 return fmt.Errorf("add ResolveAuthScheme: %w", err) 302 } 303 if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil { 304 return fmt.Errorf("add GetIdentity: %w", err) 305 } 306 if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil { 307 return fmt.Errorf("add ResolveEndpointV2: %w", err) 308 } 309 if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil { 310 return fmt.Errorf("add Signing: %w", err) 311 } 312 return nil 313 }