recursion_detection.go (2447B)
1 package middleware 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/aws/smithy-go/middleware" 7 smithyhttp "github.com/aws/smithy-go/transport/http" 8 "os" 9 ) 10 11 const envAwsLambdaFunctionName = "AWS_LAMBDA_FUNCTION_NAME" 12 const envAmznTraceID = "_X_AMZN_TRACE_ID" 13 const amznTraceIDHeader = "X-Amzn-Trace-Id" 14 15 // AddRecursionDetection adds recursionDetection to the middleware stack 16 func AddRecursionDetection(stack *middleware.Stack) error { 17 return stack.Build.Add(&RecursionDetection{}, middleware.After) 18 } 19 20 // RecursionDetection detects Lambda environment and sets its X-Ray trace ID to request header if absent 21 // to avoid recursion invocation in Lambda 22 type RecursionDetection struct{} 23 24 // ID returns the middleware identifier 25 func (m *RecursionDetection) ID() string { 26 return "RecursionDetection" 27 } 28 29 // HandleBuild detects Lambda environment and adds its trace ID to request header if absent 30 func (m *RecursionDetection) HandleBuild( 31 ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler, 32 ) ( 33 out middleware.BuildOutput, metadata middleware.Metadata, err error, 34 ) { 35 req, ok := in.Request.(*smithyhttp.Request) 36 if !ok { 37 return out, metadata, fmt.Errorf("unknown request type %T", req) 38 } 39 40 _, hasLambdaEnv := os.LookupEnv(envAwsLambdaFunctionName) 41 xAmznTraceID, hasTraceID := os.LookupEnv(envAmznTraceID) 42 value := req.Header.Get(amznTraceIDHeader) 43 // only set the X-Amzn-Trace-Id header when it is not set initially, the 44 // current environment is Lambda and the _X_AMZN_TRACE_ID env variable exists 45 if value != "" || !hasLambdaEnv || !hasTraceID { 46 return next.HandleBuild(ctx, in) 47 } 48 49 req.Header.Set(amznTraceIDHeader, percentEncode(xAmznTraceID)) 50 return next.HandleBuild(ctx, in) 51 } 52 53 func percentEncode(s string) string { 54 upperhex := "0123456789ABCDEF" 55 hexCount := 0 56 for i := 0; i < len(s); i++ { 57 c := s[i] 58 if shouldEncode(c) { 59 hexCount++ 60 } 61 } 62 63 if hexCount == 0 { 64 return s 65 } 66 67 required := len(s) + 2*hexCount 68 t := make([]byte, required) 69 j := 0 70 for i := 0; i < len(s); i++ { 71 if c := s[i]; shouldEncode(c) { 72 t[j] = '%' 73 t[j+1] = upperhex[c>>4] 74 t[j+2] = upperhex[c&15] 75 j += 3 76 } else { 77 t[j] = c 78 j++ 79 } 80 } 81 return string(t) 82 } 83 84 func shouldEncode(c byte) bool { 85 if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { 86 return false 87 } 88 switch c { 89 case '-', '=', ';', ':', '+', '&', '[', ']', '{', '}', '"', '\'', ',': 90 return false 91 default: 92 return true 93 } 94 }