middleware_capture_request_compression.go (1657B)
1 package requestcompression 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "github.com/aws/smithy-go/middleware" 8 smithyhttp "github.com/aws/smithy-go/transport/http" 9 "io" 10 "net/http" 11 ) 12 13 const captureUncompressedRequestID = "CaptureUncompressedRequest" 14 15 // AddCaptureUncompressedRequestMiddleware captures http request before compress encoding for check 16 func AddCaptureUncompressedRequestMiddleware(stack *middleware.Stack, buf *bytes.Buffer) error { 17 return stack.Serialize.Insert(&captureUncompressedRequestMiddleware{ 18 buf: buf, 19 }, "RequestCompression", middleware.Before) 20 } 21 22 type captureUncompressedRequestMiddleware struct { 23 req *http.Request 24 buf *bytes.Buffer 25 bytes []byte 26 } 27 28 // ID returns id of the captureUncompressedRequestMiddleware 29 func (*captureUncompressedRequestMiddleware) ID() string { 30 return captureUncompressedRequestID 31 } 32 33 // HandleSerialize captures request payload before it is compressed by request compression middleware 34 func (m *captureUncompressedRequestMiddleware) HandleSerialize(ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler, 35 ) ( 36 output middleware.SerializeOutput, metadata middleware.Metadata, err error, 37 ) { 38 request, ok := input.Request.(*smithyhttp.Request) 39 if !ok { 40 return output, metadata, fmt.Errorf("error when retrieving http request") 41 } 42 43 _, err = io.Copy(m.buf, request.GetStream()) 44 if err != nil { 45 return output, metadata, fmt.Errorf("error when copying http request stream: %q", err) 46 } 47 if err = request.RewindStream(); err != nil { 48 return output, metadata, fmt.Errorf("error when rewinding request stream: %q", err) 49 } 50 51 return next.HandleSerialize(ctx, input) 52 }