code.dwrz.net

Go monorepo.
Log | Files | Refs

middleware.go (13669B)


      1 package v4
      2 
      3 import (
      4 	"context"
      5 	"crypto/sha256"
      6 	"encoding/hex"
      7 	"fmt"
      8 	"io"
      9 	"net/http"
     10 	"strings"
     11 
     12 	"github.com/aws/aws-sdk-go-v2/aws"
     13 	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
     14 	v4Internal "github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4"
     15 	"github.com/aws/aws-sdk-go-v2/internal/sdk"
     16 	"github.com/aws/smithy-go/middleware"
     17 	smithyhttp "github.com/aws/smithy-go/transport/http"
     18 )
     19 
     20 const computePayloadHashMiddlewareID = "ComputePayloadHash"
     21 
     22 // HashComputationError indicates an error occurred while computing the signing hash
     23 type HashComputationError struct {
     24 	Err error
     25 }
     26 
     27 // Error is the error message
     28 func (e *HashComputationError) Error() string {
     29 	return fmt.Sprintf("failed to compute payload hash: %v", e.Err)
     30 }
     31 
     32 // Unwrap returns the underlying error if one is set
     33 func (e *HashComputationError) Unwrap() error {
     34 	return e.Err
     35 }
     36 
     37 // SigningError indicates an error condition occurred while performing SigV4 signing
     38 type SigningError struct {
     39 	Err error
     40 }
     41 
     42 func (e *SigningError) Error() string {
     43 	return fmt.Sprintf("failed to sign request: %v", e.Err)
     44 }
     45 
     46 // Unwrap returns the underlying error cause
     47 func (e *SigningError) Unwrap() error {
     48 	return e.Err
     49 }
     50 
     51 // UseDynamicPayloadSigningMiddleware swaps the compute payload sha256 middleware with a resolver middleware that
     52 // switches between unsigned and signed payload based on TLS state for request.
     53 // This middleware should not be used for AWS APIs that do not support unsigned payload signing auth.
     54 // By default, SDK uses this middleware for known AWS APIs that support such TLS based auth selection .
     55 //
     56 // Usage example -
     57 // S3 PutObject API allows unsigned payload signing auth usage when TLS is enabled, and uses this middleware to
     58 // dynamically switch between unsigned and signed payload based on TLS state for request.
     59 func UseDynamicPayloadSigningMiddleware(stack *middleware.Stack) error {
     60 	_, err := stack.Build.Swap(computePayloadHashMiddlewareID, &dynamicPayloadSigningMiddleware{})
     61 	return err
     62 }
     63 
     64 // dynamicPayloadSigningMiddleware dynamically resolves the middleware that computes and set payload sha256 middleware.
     65 type dynamicPayloadSigningMiddleware struct {
     66 }
     67 
     68 // ID returns the resolver identifier
     69 func (m *dynamicPayloadSigningMiddleware) ID() string {
     70 	return computePayloadHashMiddlewareID
     71 }
     72 
     73 // HandleBuild sets a resolver that directs to the payload sha256 compute handler.
     74 func (m *dynamicPayloadSigningMiddleware) HandleBuild(
     75 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
     76 ) (
     77 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
     78 ) {
     79 	req, ok := in.Request.(*smithyhttp.Request)
     80 	if !ok {
     81 		return out, metadata, fmt.Errorf("unknown transport type %T", in.Request)
     82 	}
     83 
     84 	// if TLS is enabled, use unsigned payload when supported
     85 	if req.IsHTTPS() {
     86 		return (&unsignedPayload{}).HandleBuild(ctx, in, next)
     87 	}
     88 
     89 	// else fall back to signed payload
     90 	return (&computePayloadSHA256{}).HandleBuild(ctx, in, next)
     91 }
     92 
     93 // unsignedPayload sets the SigV4 request payload hash to unsigned.
     94 //
     95 // Will not set the Unsigned Payload magic SHA value, if a SHA has already been
     96 // stored in the context. (e.g. application pre-computed SHA256 before making
     97 // API call).
     98 //
     99 // This middleware does not check the X-Amz-Content-Sha256 header, if that
    100 // header is serialized a middleware must translate it into the context.
    101 type unsignedPayload struct{}
    102 
    103 // AddUnsignedPayloadMiddleware adds unsignedPayload to the operation
    104 // middleware stack
    105 func AddUnsignedPayloadMiddleware(stack *middleware.Stack) error {
    106 	return stack.Build.Add(&unsignedPayload{}, middleware.After)
    107 }
    108 
    109 // ID returns the unsignedPayload identifier
    110 func (m *unsignedPayload) ID() string {
    111 	return computePayloadHashMiddlewareID
    112 }
    113 
    114 // HandleBuild sets the payload hash to be an unsigned payload
    115 func (m *unsignedPayload) HandleBuild(
    116 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
    117 ) (
    118 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
    119 ) {
    120 	// This should not compute the content SHA256 if the value is already
    121 	// known. (e.g. application pre-computed SHA256 before making API call).
    122 	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
    123 	// that header is provided a middleware must translate it into the context.
    124 	contentSHA := GetPayloadHash(ctx)
    125 	if len(contentSHA) == 0 {
    126 		contentSHA = v4Internal.UnsignedPayload
    127 	}
    128 
    129 	ctx = SetPayloadHash(ctx, contentSHA)
    130 	return next.HandleBuild(ctx, in)
    131 }
    132 
    133 // computePayloadSHA256 computes SHA256 payload hash to sign.
    134 //
    135 // Will not set the Unsigned Payload magic SHA value, if a SHA has already been
    136 // stored in the context. (e.g. application pre-computed SHA256 before making
    137 // API call).
    138 //
    139 // This middleware does not check the X-Amz-Content-Sha256 header, if that
    140 // header is serialized a middleware must translate it into the context.
    141 type computePayloadSHA256 struct{}
    142 
    143 // AddComputePayloadSHA256Middleware adds computePayloadSHA256 to the
    144 // operation middleware stack
    145 func AddComputePayloadSHA256Middleware(stack *middleware.Stack) error {
    146 	return stack.Build.Add(&computePayloadSHA256{}, middleware.After)
    147 }
    148 
    149 // RemoveComputePayloadSHA256Middleware removes computePayloadSHA256 from the
    150 // operation middleware stack
    151 func RemoveComputePayloadSHA256Middleware(stack *middleware.Stack) error {
    152 	_, err := stack.Build.Remove(computePayloadHashMiddlewareID)
    153 	return err
    154 }
    155 
    156 // ID is the middleware name
    157 func (m *computePayloadSHA256) ID() string {
    158 	return computePayloadHashMiddlewareID
    159 }
    160 
    161 // HandleBuild compute the payload hash for the request payload
    162 func (m *computePayloadSHA256) HandleBuild(
    163 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
    164 ) (
    165 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
    166 ) {
    167 	req, ok := in.Request.(*smithyhttp.Request)
    168 	if !ok {
    169 		return out, metadata, &HashComputationError{
    170 			Err: fmt.Errorf("unexpected request middleware type %T", in.Request),
    171 		}
    172 	}
    173 
    174 	// This should not compute the content SHA256 if the value is already
    175 	// known. (e.g. application pre-computed SHA256 before making API call)
    176 	// Does not have any tight coupling to the X-Amz-Content-Sha256 header, if
    177 	// that header is provided a middleware must translate it into the context.
    178 	if contentSHA := GetPayloadHash(ctx); len(contentSHA) != 0 {
    179 		return next.HandleBuild(ctx, in)
    180 	}
    181 
    182 	hash := sha256.New()
    183 	if stream := req.GetStream(); stream != nil {
    184 		_, err = io.Copy(hash, stream)
    185 		if err != nil {
    186 			return out, metadata, &HashComputationError{
    187 				Err: fmt.Errorf("failed to compute payload hash, %w", err),
    188 			}
    189 		}
    190 
    191 		if err := req.RewindStream(); err != nil {
    192 			return out, metadata, &HashComputationError{
    193 				Err: fmt.Errorf("failed to seek body to start, %w", err),
    194 			}
    195 		}
    196 	}
    197 
    198 	ctx = SetPayloadHash(ctx, hex.EncodeToString(hash.Sum(nil)))
    199 
    200 	return next.HandleBuild(ctx, in)
    201 }
    202 
    203 // SwapComputePayloadSHA256ForUnsignedPayloadMiddleware replaces the
    204 // ComputePayloadSHA256 middleware with the UnsignedPayload middleware.
    205 //
    206 // Use this to disable computing the Payload SHA256 checksum and instead use
    207 // UNSIGNED-PAYLOAD for the SHA256 value.
    208 func SwapComputePayloadSHA256ForUnsignedPayloadMiddleware(stack *middleware.Stack) error {
    209 	_, err := stack.Build.Swap(computePayloadHashMiddlewareID, &unsignedPayload{})
    210 	return err
    211 }
    212 
    213 // contentSHA256Header sets the X-Amz-Content-Sha256 header value to
    214 // the Payload hash stored in the context.
    215 type contentSHA256Header struct{}
    216 
    217 // AddContentSHA256HeaderMiddleware adds ContentSHA256Header to the
    218 // operation middleware stack
    219 func AddContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
    220 	return stack.Build.Insert(&contentSHA256Header{}, computePayloadHashMiddlewareID, middleware.After)
    221 }
    222 
    223 // RemoveContentSHA256HeaderMiddleware removes contentSHA256Header middleware
    224 // from the operation middleware stack
    225 func RemoveContentSHA256HeaderMiddleware(stack *middleware.Stack) error {
    226 	_, err := stack.Build.Remove((*contentSHA256Header)(nil).ID())
    227 	return err
    228 }
    229 
    230 // ID returns the ContentSHA256HeaderMiddleware identifier
    231 func (m *contentSHA256Header) ID() string {
    232 	return "SigV4ContentSHA256Header"
    233 }
    234 
    235 // HandleBuild sets the X-Amz-Content-Sha256 header value to the Payload hash
    236 // stored in the context.
    237 func (m *contentSHA256Header) HandleBuild(
    238 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
    239 ) (
    240 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
    241 ) {
    242 	req, ok := in.Request.(*smithyhttp.Request)
    243 	if !ok {
    244 		return out, metadata, &HashComputationError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
    245 	}
    246 
    247 	req.Header.Set(v4Internal.ContentSHAKey, GetPayloadHash(ctx))
    248 
    249 	return next.HandleBuild(ctx, in)
    250 }
    251 
    252 // SignHTTPRequestMiddlewareOptions is the configuration options for the SignHTTPRequestMiddleware middleware.
    253 type SignHTTPRequestMiddlewareOptions struct {
    254 	CredentialsProvider aws.CredentialsProvider
    255 	Signer              HTTPSigner
    256 	LogSigning          bool
    257 }
    258 
    259 // SignHTTPRequestMiddleware is a `FinalizeMiddleware` implementation for SigV4 HTTP Signing
    260 type SignHTTPRequestMiddleware struct {
    261 	credentialsProvider aws.CredentialsProvider
    262 	signer              HTTPSigner
    263 	logSigning          bool
    264 }
    265 
    266 // NewSignHTTPRequestMiddleware constructs a SignHTTPRequestMiddleware using the given Signer for signing requests
    267 func NewSignHTTPRequestMiddleware(options SignHTTPRequestMiddlewareOptions) *SignHTTPRequestMiddleware {
    268 	return &SignHTTPRequestMiddleware{
    269 		credentialsProvider: options.CredentialsProvider,
    270 		signer:              options.Signer,
    271 		logSigning:          options.LogSigning,
    272 	}
    273 }
    274 
    275 // ID is the SignHTTPRequestMiddleware identifier
    276 func (s *SignHTTPRequestMiddleware) ID() string {
    277 	return "Signing"
    278 }
    279 
    280 // HandleFinalize will take the provided input and sign the request using the SigV4 authentication scheme
    281 func (s *SignHTTPRequestMiddleware) HandleFinalize(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (
    282 	out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
    283 ) {
    284 	if !haveCredentialProvider(s.credentialsProvider) {
    285 		return next.HandleFinalize(ctx, in)
    286 	}
    287 
    288 	req, ok := in.Request.(*smithyhttp.Request)
    289 	if !ok {
    290 		return out, metadata, &SigningError{Err: fmt.Errorf("unexpected request middleware type %T", in.Request)}
    291 	}
    292 
    293 	signingName, signingRegion := awsmiddleware.GetSigningName(ctx), awsmiddleware.GetSigningRegion(ctx)
    294 	payloadHash := GetPayloadHash(ctx)
    295 	if len(payloadHash) == 0 {
    296 		return out, metadata, &SigningError{Err: fmt.Errorf("computed payload hash missing from context")}
    297 	}
    298 
    299 	credentials, err := s.credentialsProvider.Retrieve(ctx)
    300 	if err != nil {
    301 		return out, metadata, &SigningError{Err: fmt.Errorf("failed to retrieve credentials: %w", err)}
    302 	}
    303 
    304 	err = s.signer.SignHTTP(ctx, credentials, req.Request, payloadHash, signingName, signingRegion, sdk.NowTime(),
    305 		func(o *SignerOptions) {
    306 			o.Logger = middleware.GetLogger(ctx)
    307 			o.LogSigning = s.logSigning
    308 		})
    309 	if err != nil {
    310 		return out, metadata, &SigningError{Err: fmt.Errorf("failed to sign http request, %w", err)}
    311 	}
    312 
    313 	ctx = awsmiddleware.SetSigningCredentials(ctx, credentials)
    314 
    315 	return next.HandleFinalize(ctx, in)
    316 }
    317 
    318 type streamingEventsPayload struct{}
    319 
    320 // AddStreamingEventsPayload adds the streamingEventsPayload middleware to the stack.
    321 func AddStreamingEventsPayload(stack *middleware.Stack) error {
    322 	return stack.Build.Add(&streamingEventsPayload{}, middleware.After)
    323 }
    324 
    325 func (s *streamingEventsPayload) ID() string {
    326 	return computePayloadHashMiddlewareID
    327 }
    328 
    329 func (s *streamingEventsPayload) HandleBuild(
    330 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
    331 ) (
    332 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
    333 ) {
    334 	contentSHA := GetPayloadHash(ctx)
    335 	if len(contentSHA) == 0 {
    336 		contentSHA = v4Internal.StreamingEventsPayload
    337 	}
    338 
    339 	ctx = SetPayloadHash(ctx, contentSHA)
    340 
    341 	return next.HandleBuild(ctx, in)
    342 }
    343 
    344 // GetSignedRequestSignature attempts to extract the signature of the request.
    345 // Returning an error if the request is unsigned, or unable to extract the
    346 // signature.
    347 func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
    348 	const authHeaderSignatureElem = "Signature="
    349 
    350 	if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
    351 		ps := strings.Split(auth, ", ")
    352 		for _, p := range ps {
    353 			if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
    354 				sig := p[len(authHeaderSignatureElem):]
    355 				if len(sig) == 0 {
    356 					return nil, fmt.Errorf("invalid request signature authorization header")
    357 				}
    358 				return hex.DecodeString(sig)
    359 			}
    360 		}
    361 	}
    362 
    363 	if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
    364 		return hex.DecodeString(sig)
    365 	}
    366 
    367 	return nil, fmt.Errorf("request not signed")
    368 }
    369 
    370 func haveCredentialProvider(p aws.CredentialsProvider) bool {
    371 	if p == nil {
    372 		return false
    373 	}
    374 	switch p.(type) {
    375 	case aws.AnonymousCredentials,
    376 		*aws.AnonymousCredentials:
    377 		return false
    378 	}
    379 
    380 	return true
    381 }
    382 
    383 type payloadHashKey struct{}
    384 
    385 // GetPayloadHash retrieves the payload hash to use for signing
    386 //
    387 // Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
    388 // to clear all stack values.
    389 func GetPayloadHash(ctx context.Context) (v string) {
    390 	v, _ = middleware.GetStackValue(ctx, payloadHashKey{}).(string)
    391 	return v
    392 }
    393 
    394 // SetPayloadHash sets the payload hash to be used for signing the request
    395 //
    396 // Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
    397 // to clear all stack values.
    398 func SetPayloadHash(ctx context.Context, hash string) context.Context {
    399 	return middleware.WithStackValue(ctx, payloadHashKey{}, hash)
    400 }