src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

request_compression.go (3330B)


      1 // Package requestcompression implements runtime support for smithy-modeled
      2 // request compression.
      3 //
      4 // This package is designated as private and is intended for use only by the
      5 // smithy client runtime. The exported API therein is not considered stable and
      6 // is subject to breaking changes without notice.
      7 package requestcompression
      8 
      9 import (
     10 	"bytes"
     11 	"context"
     12 	"fmt"
     13 	"github.com/aws/smithy-go/middleware"
     14 	"github.com/aws/smithy-go/transport/http"
     15 	"io"
     16 )
     17 
     18 const MaxRequestMinCompressSizeBytes = 10485760
     19 
     20 // Enumeration values for supported compress Algorithms.
     21 const (
     22 	GZIP = "gzip"
     23 )
     24 
     25 type compressFunc func(io.Reader) ([]byte, error)
     26 
     27 var allowedAlgorithms = map[string]compressFunc{
     28 	GZIP: gzipCompress,
     29 }
     30 
     31 // AddRequestCompression add requestCompression middleware to op stack
     32 func AddRequestCompression(stack *middleware.Stack, disabled bool, minBytes int64, algorithms []string) error {
     33 	return stack.Serialize.Add(&requestCompression{
     34 		disableRequestCompression:   disabled,
     35 		requestMinCompressSizeBytes: minBytes,
     36 		compressAlgorithms:          algorithms,
     37 	}, middleware.After)
     38 }
     39 
     40 type requestCompression struct {
     41 	disableRequestCompression   bool
     42 	requestMinCompressSizeBytes int64
     43 	compressAlgorithms          []string
     44 }
     45 
     46 // ID returns the ID of the middleware
     47 func (m requestCompression) ID() string {
     48 	return "RequestCompression"
     49 }
     50 
     51 // HandleSerialize gzip compress the request's stream/body if enabled by config fields
     52 func (m requestCompression) HandleSerialize(
     53 	ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
     54 ) (
     55 	out middleware.SerializeOutput, metadata middleware.Metadata, err error,
     56 ) {
     57 	if m.disableRequestCompression {
     58 		return next.HandleSerialize(ctx, in)
     59 	}
     60 	// still need to check requestMinCompressSizeBytes in case it is out of range after service client config
     61 	if m.requestMinCompressSizeBytes < 0 || m.requestMinCompressSizeBytes > MaxRequestMinCompressSizeBytes {
     62 		return out, metadata, fmt.Errorf("invalid range for min request compression size bytes %d, must be within 0 and 10485760 inclusively", m.requestMinCompressSizeBytes)
     63 	}
     64 
     65 	req, ok := in.Request.(*http.Request)
     66 	if !ok {
     67 		return out, metadata, fmt.Errorf("unknown request type %T", req)
     68 	}
     69 
     70 	for _, algorithm := range m.compressAlgorithms {
     71 		compressFunc := allowedAlgorithms[algorithm]
     72 		if compressFunc != nil {
     73 			if stream := req.GetStream(); stream != nil {
     74 				size, found, err := req.StreamLength()
     75 				if err != nil {
     76 					return out, metadata, fmt.Errorf("error while finding request stream length, %v", err)
     77 				} else if !found || size < m.requestMinCompressSizeBytes {
     78 					return next.HandleSerialize(ctx, in)
     79 				}
     80 
     81 				compressedBytes, err := compressFunc(stream)
     82 				if err != nil {
     83 					return out, metadata, fmt.Errorf("failed to compress request stream, %v", err)
     84 				}
     85 
     86 				var newReq *http.Request
     87 				if newReq, err = req.SetStream(bytes.NewReader(compressedBytes)); err != nil {
     88 					return out, metadata, fmt.Errorf("failed to set request stream, %v", err)
     89 				}
     90 				*req = *newReq
     91 
     92 				if val := req.Header.Get("Content-Encoding"); val != "" {
     93 					req.Header.Set("Content-Encoding", fmt.Sprintf("%s, %s", val, algorithm))
     94 				} else {
     95 					req.Header.Set("Content-Encoding", algorithm)
     96 				}
     97 			}
     98 			break
     99 		}
    100 	}
    101 
    102 	return next.HandleSerialize(ctx, in)
    103 }