middleware_content_length.go (2703B)
1 package http 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/aws/smithy-go/middleware" 8 ) 9 10 // ComputeContentLength provides a middleware to set the content-length 11 // header for the length of a serialize request body. 12 type ComputeContentLength struct { 13 } 14 15 // AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware 16 // stack's Build step. 17 func AddComputeContentLengthMiddleware(stack *middleware.Stack) error { 18 return stack.Build.Add(&ComputeContentLength{}, middleware.After) 19 } 20 21 // ID returns the identifier for the ComputeContentLength. 22 func (m *ComputeContentLength) ID() string { return "ComputeContentLength" } 23 24 // HandleBuild adds the length of the serialized request to the HTTP header 25 // if the length can be determined. 26 func (m *ComputeContentLength) HandleBuild( 27 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, 28 ) ( 29 out middleware.BuildOutput, metadata middleware.Metadata, err error, 30 ) { 31 req, ok := in.Request.(*Request) 32 if !ok { 33 return out, metadata, fmt.Errorf("unknown request type %T", req) 34 } 35 36 // do nothing if request content-length was set to 0 or above. 37 if req.ContentLength >= 0 { 38 return next.HandleBuild(ctx, in) 39 } 40 41 // attempt to compute stream length 42 if n, ok, err := req.StreamLength(); err != nil { 43 return out, metadata, fmt.Errorf( 44 "failed getting length of request stream, %w", err) 45 } else if ok { 46 req.ContentLength = n 47 } 48 49 return next.HandleBuild(ctx, in) 50 } 51 52 // validateContentLength provides a middleware to validate the content-length 53 // is valid (greater than zero), for the serialized request payload. 54 type validateContentLength struct{} 55 56 // ValidateContentLengthHeader adds middleware that validates request content-length 57 // is set to value greater than zero. 58 func ValidateContentLengthHeader(stack *middleware.Stack) error { 59 return stack.Build.Add(&validateContentLength{}, middleware.After) 60 } 61 62 // ID returns the identifier for the ComputeContentLength. 63 func (m *validateContentLength) ID() string { return "ValidateContentLength" } 64 65 // HandleBuild adds the length of the serialized request to the HTTP header 66 // if the length can be determined. 67 func (m *validateContentLength) HandleBuild( 68 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, 69 ) ( 70 out middleware.BuildOutput, metadata middleware.Metadata, err error, 71 ) { 72 req, ok := in.Request.(*Request) 73 if !ok { 74 return out, metadata, fmt.Errorf("unknown request type %T", req) 75 } 76 77 // if request content-length was set to less than 0, return an error 78 if req.ContentLength < 0 { 79 return out, metadata, fmt.Errorf( 80 "content length for payload is required and must be at least 0") 81 } 82 83 return next.HandleBuild(ctx, in) 84 }