code.dwrz.net

Go monorepo.
Log | Files | Refs

middleware_content_length.go (2703B)


      1 package http
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 
      7 	"github.com/aws/smithy-go/middleware"
      8 )
      9 
     // ComputeContentLength provides a middleware that sets the request's
// ContentLength from the length of the serialized request body stream, when
// the request does not already carry a content length of zero or greater.
type ComputeContentLength struct{}
     14 
     15 // AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
     16 // stack's Build step.
     17 func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
     18 	return stack.Build.Add(&ComputeContentLength{}, middleware.After)
     19 }
     20 
     21 // ID returns the identifier for the ComputeContentLength.
     22 func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
     23 
     24 // HandleBuild adds the length of the serialized request to the HTTP header
     25 // if the length can be determined.
     26 func (m *ComputeContentLength) HandleBuild(
     27 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
     28 ) (
     29 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
     30 ) {
     31 	req, ok := in.Request.(*Request)
     32 	if !ok {
     33 		return out, metadata, fmt.Errorf("unknown request type %T", req)
     34 	}
     35 
     36 	// do nothing if request content-length was set to 0 or above.
     37 	if req.ContentLength >= 0 {
     38 		return next.HandleBuild(ctx, in)
     39 	}
     40 
     41 	// attempt to compute stream length
     42 	if n, ok, err := req.StreamLength(); err != nil {
     43 		return out, metadata, fmt.Errorf(
     44 			"failed getting length of request stream, %w", err)
     45 	} else if ok {
     46 		req.ContentLength = n
     47 	}
     48 
     49 	return next.HandleBuild(ctx, in)
     50 }
     51 
     // validateContentLength provides a middleware that rejects a serialized
// request whose content length is negative; a content length of zero or
// greater is considered valid.
type validateContentLength struct{}
     55 
     56 // ValidateContentLengthHeader adds middleware that validates request content-length
     57 // is set to value greater than zero.
     58 func ValidateContentLengthHeader(stack *middleware.Stack) error {
     59 	return stack.Build.Add(&validateContentLength{}, middleware.After)
     60 }
     61 
     62 // ID returns the identifier for the ComputeContentLength.
     63 func (m *validateContentLength) ID() string { return "ValidateContentLength" }
     64 
     65 // HandleBuild adds the length of the serialized request to the HTTP header
     66 // if the length can be determined.
     67 func (m *validateContentLength) HandleBuild(
     68 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
     69 ) (
     70 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
     71 ) {
     72 	req, ok := in.Request.(*Request)
     73 	if !ok {
     74 		return out, metadata, fmt.Errorf("unknown request type %T", req)
     75 	}
     76 
     77 	// if request content-length was set to less than 0, return an error
     78 	if req.ContentLength < 0 {
     79 		return out, metadata, fmt.Errorf(
     80 			"content length for payload is required and must be at least 0")
     81 	}
     82 
     83 	return next.HandleBuild(ctx, in)
     84 }