// provider.go (6875B)

1 package ec2rolecreds 2 3 import ( 4 "bufio" 5 "context" 6 "encoding/json" 7 "fmt" 8 "math" 9 "path" 10 "strings" 11 "time" 12 13 "github.com/aws/aws-sdk-go-v2/aws" 14 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" 15 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" 16 "github.com/aws/aws-sdk-go-v2/internal/sdk" 17 "github.com/aws/smithy-go" 18 "github.com/aws/smithy-go/logging" 19 "github.com/aws/smithy-go/middleware" 20 ) 21 22 // ProviderName provides a name of EC2Role provider 23 const ProviderName = "EC2RoleProvider" 24 25 // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the 26 // GetMetadata operation. 27 type GetMetadataAPIClient interface { 28 GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error) 29 } 30 31 // A Provider retrieves credentials from the EC2 service, and keeps track if 32 // those credentials are expired. 33 // 34 // The New function must be used to create the with a custom EC2 IMDS client. 35 // 36 // p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{ 37 // o.Client = imds.New(imds.Options{/* custom options */}) 38 // }) 39 type Provider struct { 40 options Options 41 } 42 43 // Options is a list of user settable options for setting the behavior of the Provider. 44 type Options struct { 45 // The API client that will be used by the provider to make GetMetadata API 46 // calls to EC2 IMDS. 47 // 48 // If nil, the provider will default to the EC2 IMDS client. 49 Client GetMetadataAPIClient 50 } 51 52 // New returns an initialized Provider value configured to retrieve 53 // credentials from EC2 Instance Metadata service. 54 func New(optFns ...func(*Options)) *Provider { 55 options := Options{} 56 57 for _, fn := range optFns { 58 fn(&options) 59 } 60 61 if options.Client == nil { 62 options.Client = imds.New(imds.Options{}) 63 } 64 65 return &Provider{ 66 options: options, 67 } 68 } 69 70 // Retrieve retrieves credentials from the EC2 service. 
Error will be returned 71 // if the request fails, or unable to extract the desired credentials. 72 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { 73 credsList, err := requestCredList(ctx, p.options.Client) 74 if err != nil { 75 return aws.Credentials{Source: ProviderName}, err 76 } 77 78 if len(credsList) == 0 { 79 return aws.Credentials{Source: ProviderName}, 80 fmt.Errorf("unexpected empty EC2 IMDS role list") 81 } 82 credsName := credsList[0] 83 84 roleCreds, err := requestCred(ctx, p.options.Client, credsName) 85 if err != nil { 86 return aws.Credentials{Source: ProviderName}, err 87 } 88 89 creds := aws.Credentials{ 90 AccessKeyID: roleCreds.AccessKeyID, 91 SecretAccessKey: roleCreds.SecretAccessKey, 92 SessionToken: roleCreds.Token, 93 Source: ProviderName, 94 95 CanExpire: true, 96 Expires: roleCreds.Expiration, 97 } 98 99 // Cap role credentials Expires to 1 hour so they can be refreshed more 100 // often. Jitter will be applied credentials cache if being used. 101 if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) { 102 creds.Expires = anHour 103 } 104 105 return creds, nil 106 } 107 108 // HandleFailToRefresh will extend the credentials Expires time if it it is 109 // expired. If the credentials will not expire within the minimum time, they 110 // will be returned. 111 // 112 // If the credentials cannot expire, the original error will be returned. 113 func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) ( 114 aws.Credentials, error, 115 ) { 116 if !prevCreds.CanExpire { 117 return aws.Credentials{}, err 118 } 119 120 if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) { 121 return prevCreds, nil 122 } 123 124 newCreds := prevCreds 125 randFloat64, err := sdkrand.CryptoRandFloat64() 126 if err != nil { 127 return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err) 128 } 129 130 // Random distribution of [5,15) minutes. 
131 expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute 132 newCreds.Expires = sdk.NowTime().Add(expireOffset) 133 134 logger := middleware.GetLogger(ctx) 135 logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes())) 136 137 return newCreds, nil 138 } 139 140 // AdjustExpiresBy will adds the passed in duration to the passed in 141 // credential's Expires time, unless the time until Expires is less than 15 142 // minutes. Returns the credentials, even if not updated. 143 func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) ( 144 aws.Credentials, error, 145 ) { 146 if !creds.CanExpire { 147 return creds, nil 148 } 149 if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) { 150 return creds, nil 151 } 152 153 creds.Expires = creds.Expires.Add(dur) 154 return creds, nil 155 } 156 157 // ec2RoleCredRespBody provides the shape for unmarshaling credential 158 // request responses. 159 type ec2RoleCredRespBody struct { 160 // Success State 161 Expiration time.Time 162 AccessKeyID string 163 SecretAccessKey string 164 Token string 165 166 // Error state 167 Code string 168 Message string 169 } 170 171 const iamSecurityCredsPath = "/iam/security-credentials/" 172 173 // requestCredList requests a list of credentials from the EC2 service. 
If 174 // there are no credentials, or there is an error making or receiving the 175 // request 176 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) { 177 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ 178 Path: iamSecurityCredsPath, 179 }) 180 if err != nil { 181 return nil, fmt.Errorf("no EC2 IMDS role found, %w", err) 182 } 183 defer resp.Content.Close() 184 185 credsList := []string{} 186 s := bufio.NewScanner(resp.Content) 187 for s.Scan() { 188 credsList = append(credsList, s.Text()) 189 } 190 191 if err := s.Err(); err != nil { 192 return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err) 193 } 194 195 return credsList, nil 196 } 197 198 // requestCred requests the credentials for a specific credentials from the EC2 service. 199 // 200 // If the credentials cannot be found, or there is an error reading the response 201 // and error will be returned. 202 func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) { 203 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ 204 Path: path.Join(iamSecurityCredsPath, credsName), 205 }) 206 if err != nil { 207 return ec2RoleCredRespBody{}, 208 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w", 209 credsName, err) 210 } 211 defer resp.Content.Close() 212 213 var respCreds ec2RoleCredRespBody 214 if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil { 215 return ec2RoleCredRespBody{}, 216 fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w", 217 credsName, err) 218 } 219 220 if !strings.EqualFold(respCreds.Code, "Success") { 221 // If an error code was returned something failed requesting the role. 222 return ec2RoleCredRespBody{}, 223 fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w", 224 credsName, 225 &smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message}) 226 } 227 228 return respCreds, nil 229 }