src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

recursion_detection.go (2447B)


      1 package middleware
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"github.com/aws/smithy-go/middleware"
      7 	smithyhttp "github.com/aws/smithy-go/transport/http"
      8 	"os"
      9 )
     10 
     11 const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME"
     12 const envAmznTraceID = "_X_AMZN_TRACE_ID"
     13 const amznTraceIDHeader = "X-Amzn-Trace-Id"
     14 
     15 // AddRecursionDetection adds recursionDetection to the middleware stack
     16 func AddRecursionDetection(stack *middleware.Stack) error {
     17 	return stack.Build.Add(&RecursionDetection{}, middleware.After)
     18 }
     19 
     20 // RecursionDetection detects Lambda environment and sets its X-Ray trace ID to request header if absent
     21 // to avoid recursion invocation in Lambda
     22 type RecursionDetection struct{}
     23 
     24 // ID returns the middleware identifier
     25 func (m *RecursionDetection) ID() string {
     26 	return "RecursionDetection"
     27 }
     28 
     29 // HandleBuild detects Lambda environment and adds its trace ID to request header if absent
     30 func (m *RecursionDetection) HandleBuild(
     31 	ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
     32 ) (
     33 	out middleware.BuildOutput, metadata middleware.Metadata, err error,
     34 ) {
     35 	req, ok := in.Request.(*smithyhttp.Request)
     36 	if !ok {
     37 		return out, metadata, fmt.Errorf("unknown request type %T", req)
     38 	}
     39 
     40 	_, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName)
     41 	xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID)
     42 	value := req.Header.Get(amznTraceIDHeader)
     43 	// only set the X-Amzn-Trace-Id header when it is not set initially, the
     44 	// current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists
     45 	if value != "" || !hasLambdaEnv || !hasTraceID {
     46 		return next.HandleBuild(ctx, in)
     47 	}
     48 
     49 	req.Header.Set(amznTraceIDHeader, percentEncode(xAmznTraceID))
     50 	return next.HandleBuild(ctx, in)
     51 }
     52 
     53 func percentEncode(s string) string {
     54 	upperhex := "0123456789ABCDEF"
     55 	hexCount := 0
     56 	for i := 0; i < len(s); i++ {
     57 		c := s[i]
     58 		if shouldEncode(c) {
     59 			hexCount++
     60 		}
     61 	}
     62 
     63 	if hexCount == 0 {
     64 		return s
     65 	}
     66 
     67 	required := len(s) + 2*hexCount
     68 	t := make([]byte, required)
     69 	j := 0
     70 	for i := 0; i < len(s); i++ {
     71 		if c := s[i]; shouldEncode(c) {
     72 			t[j] = '%'
     73 			t[j+1] = upperhex[c>>4]
     74 			t[j+2] = upperhex[c&15]
     75 			j += 3
     76 		} else {
     77 			t[j] = c
     78 			j++
     79 		}
     80 	}
     81 	return string(t)
     82 }
     83 
     84 func shouldEncode(c byte) bool {
     85 	if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
     86 		return false
     87 	}
     88 	switch c {
     89 	case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',':
     90 		return false
     91 	default:
     92 		return true
     93 	}
     94 }