request_middleware.go (8600B)
1 package imds 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io/ioutil" 8 "net/url" 9 "path" 10 "time" 11 12 awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" 13 "github.com/aws/aws-sdk-go-v2/aws/retry" 14 "github.com/aws/smithy-go/middleware" 15 smithyhttp "github.com/aws/smithy-go/transport/http" 16 ) 17 18 func addAPIRequestMiddleware(stack *middleware.Stack, 19 options Options, 20 operation string, 21 getPath func(interface{}) (string, error), 22 getOutput func(*smithyhttp.Response) (interface{}, error), 23 ) (err error) { 24 err = addRequestMiddleware(stack, options, "GET", operation, getPath, getOutput) 25 if err != nil { 26 return err 27 } 28 29 // Token Serializer build and state management. 30 if !options.disableAPIToken { 31 err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After) 32 if err != nil { 33 return err 34 } 35 36 err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before) 37 if err != nil { 38 return err 39 } 40 } 41 42 return nil 43 } 44 45 func addRequestMiddleware(stack *middleware.Stack, 46 options Options, 47 method string, 48 operation string, 49 getPath func(interface{}) (string, error), 50 getOutput func(*smithyhttp.Response) (interface{}, error), 51 ) (err error) { 52 err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack) 53 if err != nil { 54 return err 55 } 56 57 // Operation timeout 58 err = stack.Initialize.Add(&operationTimeout{ 59 DefaultTimeout: defaultOperationTimeout, 60 }, middleware.Before) 61 if err != nil { 62 return err 63 } 64 65 // Operation Serializer 66 err = stack.Serialize.Add(&serializeRequest{ 67 GetPath: getPath, 68 Method: method, 69 }, middleware.After) 70 if err != nil { 71 return err 72 } 73 74 // Operation endpoint resolver 75 err = stack.Serialize.Insert(&resolveEndpoint{ 76 Endpoint: options.Endpoint, 77 EndpointMode: options.EndpointMode, 78 }, "OperationSerializer", middleware.Before) 79 if err != nil { 80 return err 81 } 82 83 // Operation Deserializer 84 err = stack.Deserialize.Add(&deserializeResponse{ 85 GetOutput: getOutput, 86 }, middleware.After) 87 if err != nil { 88 return err 89 } 90 91 err = stack.Deserialize.Add(&smithyhttp.RequestResponseLogger{ 92 LogRequest: options.ClientLogMode.IsRequest(), 93 LogRequestWithBody: options.ClientLogMode.IsRequestWithBody(), 94 LogResponse: options.ClientLogMode.IsResponse(), 95 LogResponseWithBody: options.ClientLogMode.IsResponseWithBody(), 96 }, middleware.After) 97 if err != nil { 98 return err 99 } 100 101 err = addSetLoggerMiddleware(stack, options) 102 if err != nil { 103 return err 104 } 105 106 if err := addProtocolFinalizerMiddlewares(stack, options, operation); err != nil { 107 return fmt.Errorf("add protocol finalizers: %w", err) 108 } 109 110 // Retry support 111 return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{ 112 Retryer: options.Retryer, 113 LogRetryAttempts: options.ClientLogMode.IsRetries(), 114 }) 115 } 116 117 func addSetLoggerMiddleware(stack *middleware.Stack, o Options) error { 118 return middleware.AddSetLoggerMiddleware(stack, o.Logger) 119 } 120 121 type serializeRequest struct { 122 GetPath func(interface{}) (string, error) 123 Method string 124 } 125 126 func (*serializeRequest) ID() string { 127 return "OperationSerializer" 128 } 129 130 func (m *serializeRequest) HandleSerialize( 131 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 132 ) ( 133 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 134 ) { 135 request, ok := in.Request.(*smithyhttp.Request) 136 if !ok { 137 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 138 } 139 140 reqPath, err := m.GetPath(in.Parameters) 141 if err != nil { 142 return out, metadata, fmt.Errorf("unable to get request URL path, %w", err) 143 } 144 145 request.Request.URL.Path = reqPath 146 request.Request.Method = m.Method 147 148 return next.HandleSerialize(ctx, in) 149 } 150 151 type deserializeResponse struct { 152 GetOutput func(*smithyhttp.Response) (interface{}, error) 153 } 154 155 func (*deserializeResponse) ID() string { 156 return "OperationDeserializer" 157 } 158 159 func (m *deserializeResponse) HandleDeserialize( 160 ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, 161 ) ( 162 out middleware.DeserializeOutput, metadata middleware.Metadata, err error, 163 ) { 164 out, metadata, err = next.HandleDeserialize(ctx, in) 165 if err != nil { 166 return out, metadata, err 167 } 168 169 resp, ok := out.RawResponse.(*smithyhttp.Response) 170 if !ok { 171 return out, metadata, fmt.Errorf( 172 "unexpected transport response type, %T, want %T", out.RawResponse, resp) 173 } 174 defer resp.Body.Close() 175 176 // read the full body so that any operation timeouts cleanup will not race 177 // the body being read. 178 body, err := ioutil.ReadAll(resp.Body) 179 if err != nil { 180 return out, metadata, fmt.Errorf("read response body failed, %w", err) 181 } 182 resp.Body = ioutil.NopCloser(bytes.NewReader(body)) 183 184 // Anything that's not 200 |< 300 is error 185 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 186 return out, metadata, &smithyhttp.ResponseError{ 187 Response: resp, 188 Err: fmt.Errorf("request to EC2 IMDS failed"), 189 } 190 } 191 192 result, err := m.GetOutput(resp) 193 if err != nil { 194 return out, metadata, fmt.Errorf( 195 "unable to get deserialized result for response, %w", err, 196 ) 197 } 198 out.Result = result 199 200 return out, metadata, err 201 } 202 203 type resolveEndpoint struct { 204 Endpoint string 205 EndpointMode EndpointModeState 206 } 207 208 func (*resolveEndpoint) ID() string { 209 return "ResolveEndpoint" 210 } 211 212 func (m *resolveEndpoint) HandleSerialize( 213 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 214 ) ( 215 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 216 ) { 217 218 req, ok := in.Request.(*smithyhttp.Request) 219 if !ok { 220 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 221 } 222 223 var endpoint string 224 if len(m.Endpoint) > 0 { 225 endpoint = m.Endpoint 226 } else { 227 switch m.EndpointMode { 228 case EndpointModeStateIPv6: 229 endpoint = defaultIPv6Endpoint 230 case EndpointModeStateIPv4: 231 fallthrough 232 case EndpointModeStateUnset: 233 endpoint = defaultIPv4Endpoint 234 default: 235 return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode") 236 } 237 } 238 239 req.URL, err = url.Parse(endpoint) 240 if err != nil { 241 return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err) 242 } 243 244 return next.HandleSerialize(ctx, in) 245 } 246 247 const ( 248 defaultOperationTimeout = 5 * time.Second 249 ) 250 251 // operationTimeout adds a timeout on the middleware stack if the Context the 252 // stack was called with does not have a deadline. The next middleware must 253 // complete before the timeout, or the context will be canceled. 254 // 255 // If DefaultTimeout is zero, no default timeout will be used if the Context 256 // does not have a timeout. 257 // 258 // The next middleware must also ensure that any resources that are also 259 // canceled by the stack's context are completely consumed before returning. 260 // Otherwise the timeout cleanup will race the resource being consumed 261 // upstream. 262 type operationTimeout struct { 263 DefaultTimeout time.Duration 264 } 265 266 func (*operationTimeout) ID() string { return "OperationTimeout" } 267 268 func (m *operationTimeout) HandleInitialize( 269 ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler, 270 ) ( 271 output middleware.InitializeOutput, metadata middleware.Metadata, err error, 272 ) { 273 if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 { 274 var cancelFn func() 275 ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout) 276 defer cancelFn() 277 } 278 279 return next.HandleInitialize(ctx, input) 280 } 281 282 // appendURIPath joins a URI path component to the existing path with `/` 283 // separators between the path components. If the path being added ends with a 284 // trailing `/` that slash will be maintained. 285 func appendURIPath(base, add string) string { 286 reqPath := path.Join(base, add) 287 if len(add) != 0 && add[len(add)-1] == '/' { 288 reqPath += "/" 289 } 290 return reqPath 291 } 292 293 func addProtocolFinalizerMiddlewares(stack *middleware.Stack, options Options, operation string) error { 294 if err := stack.Finalize.Add(&resolveAuthSchemeMiddleware{operation: operation, options: options}, middleware.Before); err != nil { 295 return fmt.Errorf("add ResolveAuthScheme: %w", err) 296 } 297 if err := stack.Finalize.Insert(&getIdentityMiddleware{options: options}, "ResolveAuthScheme", middleware.After); err != nil { 298 return fmt.Errorf("add GetIdentity: %w", err) 299 } 300 if err := stack.Finalize.Insert(&resolveEndpointV2Middleware{options: options}, "GetIdentity", middleware.After); err != nil { 301 return fmt.Errorf("add ResolveEndpointV2: %w", err) 302 } 303 if err := stack.Finalize.Insert(&signRequestMiddleware{}, "ResolveEndpointV2", middleware.After); err != nil { 304 return fmt.Errorf("add Signing: %w", err) 305 } 306 return nil 307 }