request_middleware.go (6990B)
1 package imds 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io/ioutil" 8 "net/url" 9 "path" 10 "time" 11 12 awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" 13 "github.com/aws/aws-sdk-go-v2/aws/retry" 14 "github.com/aws/smithy-go/middleware" 15 smithyhttp "github.com/aws/smithy-go/transport/http" 16 ) 17 18 func addAPIRequestMiddleware(stack *middleware.Stack, 19 options Options, 20 getPath func(interface{}) (string, error), 21 getOutput func(*smithyhttp.Response) (interface{}, error), 22 ) (err error) { 23 err = addRequestMiddleware(stack, options, "GET", getPath, getOutput) 24 if err != nil { 25 return err 26 } 27 28 // Token Serializer build and state management. 29 if !options.disableAPIToken { 30 err = stack.Finalize.Insert(options.tokenProvider, (*retry.Attempt)(nil).ID(), middleware.After) 31 if err != nil { 32 return err 33 } 34 35 err = stack.Deserialize.Insert(options.tokenProvider, "OperationDeserializer", middleware.Before) 36 if err != nil { 37 return err 38 } 39 } 40 41 return nil 42 } 43 44 func addRequestMiddleware(stack *middleware.Stack, 45 options Options, 46 method string, 47 getPath func(interface{}) (string, error), 48 getOutput func(*smithyhttp.Response) (interface{}, error), 49 ) (err error) { 50 err = awsmiddleware.AddSDKAgentKey(awsmiddleware.FeatureMetadata, "ec2-imds")(stack) 51 if err != nil { 52 return err 53 } 54 55 // Operation timeout 56 err = stack.Initialize.Add(&operationTimeout{ 57 DefaultTimeout: defaultOperationTimeout, 58 }, middleware.Before) 59 if err != nil { 60 return err 61 } 62 63 // Operation Serializer 64 err = stack.Serialize.Add(&serializeRequest{ 65 GetPath: getPath, 66 Method: method, 67 }, middleware.After) 68 if err != nil { 69 return err 70 } 71 72 // Operation endpoint resolver 73 err = stack.Serialize.Insert(&resolveEndpoint{ 74 Endpoint: options.Endpoint, 75 EndpointMode: options.EndpointMode, 76 }, "OperationSerializer", middleware.Before) 77 if err != nil { 78 return err 79 } 80 81 // Operation Deserializer 82 err = stack.Deserialize.Add(&deserializeResponse{ 83 GetOutput: getOutput, 84 }, middleware.After) 85 if err != nil { 86 return err 87 } 88 89 // Retry support 90 return retry.AddRetryMiddlewares(stack, retry.AddRetryMiddlewaresOptions{ 91 Retryer: options.Retryer, 92 LogRetryAttempts: options.ClientLogMode.IsRetries(), 93 }) 94 } 95 96 type serializeRequest struct { 97 GetPath func(interface{}) (string, error) 98 Method string 99 } 100 101 func (*serializeRequest) ID() string { 102 return "OperationSerializer" 103 } 104 105 func (m *serializeRequest) HandleSerialize( 106 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 107 ) ( 108 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 109 ) { 110 request, ok := in.Request.(*smithyhttp.Request) 111 if !ok { 112 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 113 } 114 115 reqPath, err := m.GetPath(in.Parameters) 116 if err != nil { 117 return out, metadata, fmt.Errorf("unable to get request URL path, %w", err) 118 } 119 120 request.Request.URL.Path = reqPath 121 request.Request.Method = m.Method 122 123 return next.HandleSerialize(ctx, in) 124 } 125 126 type deserializeResponse struct { 127 GetOutput func(*smithyhttp.Response) (interface{}, error) 128 } 129 130 func (*deserializeResponse) ID() string { 131 return "OperationDeserializer" 132 } 133 134 func (m *deserializeResponse) HandleDeserialize( 135 ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, 136 ) ( 137 out middleware.DeserializeOutput, metadata middleware.Metadata, err error, 138 ) { 139 out, metadata, err = next.HandleDeserialize(ctx, in) 140 if err != nil { 141 return out, metadata, err 142 } 143 144 resp, ok := out.RawResponse.(*smithyhttp.Response) 145 if !ok { 146 return out, metadata, fmt.Errorf( 147 "unexpected transport response type, %T, want %T", out.RawResponse, resp) 148 } 149 defer resp.Body.Close() 150 151 // read the full body so that any operation timeouts cleanup will not race 152 // the body being read. 153 body, err := ioutil.ReadAll(resp.Body) 154 if err != nil { 155 return out, metadata, fmt.Errorf("read response body failed, %w", err) 156 } 157 resp.Body = ioutil.NopCloser(bytes.NewReader(body)) 158 159 // Anything that's not 200 |< 300 is error 160 if resp.StatusCode < 200 || resp.StatusCode >= 300 { 161 return out, metadata, &smithyhttp.ResponseError{ 162 Response: resp, 163 Err: fmt.Errorf("request to EC2 IMDS failed"), 164 } 165 } 166 167 result, err := m.GetOutput(resp) 168 if err != nil { 169 return out, metadata, fmt.Errorf( 170 "unable to get deserialized result for response, %w", err, 171 ) 172 } 173 out.Result = result 174 175 return out, metadata, err 176 } 177 178 type resolveEndpoint struct { 179 Endpoint string 180 EndpointMode EndpointModeState 181 } 182 183 func (*resolveEndpoint) ID() string { 184 return "ResolveEndpoint" 185 } 186 187 func (m *resolveEndpoint) HandleSerialize( 188 ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, 189 ) ( 190 out middleware.SerializeOutput, metadata middleware.Metadata, err error, 191 ) { 192 193 req, ok := in.Request.(*smithyhttp.Request) 194 if !ok { 195 return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) 196 } 197 198 var endpoint string 199 if len(m.Endpoint) > 0 { 200 endpoint = m.Endpoint 201 } else { 202 switch m.EndpointMode { 203 case EndpointModeStateIPv6: 204 endpoint = defaultIPv6Endpoint 205 case EndpointModeStateIPv4: 206 fallthrough 207 case EndpointModeStateUnset: 208 endpoint = defaultIPv4Endpoint 209 default: 210 return out, metadata, fmt.Errorf("unsupported IMDS endpoint mode") 211 } 212 } 213 214 req.URL, err = url.Parse(endpoint) 215 if err != nil { 216 return out, metadata, fmt.Errorf("failed to parse endpoint URL: %w", err) 217 } 218 219 return next.HandleSerialize(ctx, in) 220 } 221 222 const ( 223 defaultOperationTimeout = 5 * time.Second 224 ) 225 226 // operationTimeout adds a timeout on the middleware stack if the Context the 227 // stack was called with does not have a deadline. The next middleware must 228 // complete before the timeout, or the context will be canceled. 229 // 230 // If DefaultTimeout is zero, no default timeout will be used if the Context 231 // does not have a timeout. 232 // 233 // The next middleware must also ensure that any resources that are also 234 // canceled by the stack's context are completely consumed before returning. 235 // Otherwise the timeout cleanup will race the resource being consumed 236 // upstream. 237 type operationTimeout struct { 238 DefaultTimeout time.Duration 239 } 240 241 func (*operationTimeout) ID() string { return "OperationTimeout" } 242 243 func (m *operationTimeout) HandleInitialize( 244 ctx context.Context, input middleware.InitializeInput, next middleware.InitializeHandler, 245 ) ( 246 output middleware.InitializeOutput, metadata middleware.Metadata, err error, 247 ) { 248 if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 { 249 var cancelFn func() 250 ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout) 251 defer cancelFn() 252 } 253 254 return next.HandleInitialize(ctx, input) 255 } 256 257 // appendURIPath joins a URI path component to the existing path with `/` 258 // separators between the path components. If the path being added ends with a 259 // trailing `/` that slash will be maintained. 260 func appendURIPath(base, add string) string { 261 reqPath := path.Join(base, add) 262 if len(add) != 0 && add[len(add)-1] == '/' { 263 reqPath += "/" 264 } 265 return reqPath 266 }