provider.go (7455B)
1 package ec2rolecreds 2 3 import ( 4 "bufio" 5 "context" 6 "encoding/json" 7 "fmt" 8 "math" 9 "path" 10 "strings" 11 "time" 12 13 "github.com/aws/aws-sdk-go-v2/aws" 14 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" 15 sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand" 16 "github.com/aws/aws-sdk-go-v2/internal/sdk" 17 "github.com/aws/smithy-go" 18 "github.com/aws/smithy-go/logging" 19 "github.com/aws/smithy-go/middleware" 20 ) 21 22 // ProviderName provides a name of EC2Role provider 23 const ProviderName = "EC2RoleProvider" 24 25 // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the 26 // GetMetadata operation. 27 type GetMetadataAPIClient interface { 28 GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error) 29 } 30 31 // A Provider retrieves credentials from the EC2 service, and keeps track if 32 // those credentials are expired. 33 // 34 // The New function must be used to create the with a custom EC2 IMDS client. 35 // 36 // p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{ 37 // o.Client = imds.New(imds.Options{/* custom options */}) 38 // }) 39 type Provider struct { 40 options Options 41 } 42 43 // Options is a list of user settable options for setting the behavior of the Provider. 44 type Options struct { 45 // The API client that will be used by the provider to make GetMetadata API 46 // calls to EC2 IMDS. 47 // 48 // If nil, the provider will default to the EC2 IMDS client. 49 Client GetMetadataAPIClient 50 51 // The chain of providers that was used to create this provider 52 // These values are for reporting purposes and are not meant to be set up directly 53 CredentialSources []aws.CredentialSource 54 } 55 56 // New returns an initialized Provider value configured to retrieve 57 // credentials from EC2 Instance Metadata service. 
func New(optFns ...func(*Options)) *Provider {
	options := Options{}

	for _, fn := range optFns {
		fn(&options)
	}

	if options.Client == nil {
		options.Client = imds.New(imds.Options{})
	}

	return &Provider{
		options: options,
	}
}

// Retrieve retrieves credentials from the EC2 service.
//
// It lists the IAM roles exposed by EC2 IMDS and fetches the credentials of
// the first role listed. An error is returned if either IMDS request fails,
// the response cannot be read or decoded, or no role is listed. The returned
// credentials can expire, and their Expires time is capped at one hour from
// now.
func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
	credsList, err := requestCredList(ctx, p.options.Client)
	if err != nil {
		return aws.Credentials{Source: ProviderName}, err
	}

	if len(credsList) == 0 {
		return aws.Credentials{Source: ProviderName},
			fmt.Errorf("unexpected empty EC2 IMDS role list")
	}
	credsName := credsList[0]

	roleCreds, err := requestCred(ctx, p.options.Client, credsName)
	if err != nil {
		return aws.Credentials{Source: ProviderName}, err
	}

	creds := aws.Credentials{
		AccessKeyID:     roleCreds.AccessKeyID,
		SecretAccessKey: roleCreds.SecretAccessKey,
		SessionToken:    roleCreds.Token,
		Source:          ProviderName,

		CanExpire: true,
		Expires:   roleCreds.Expiration,
	}

	// Cap role credentials Expires to 1 hour so they can be refreshed more
	// often. Jitter will be applied by the credentials cache, if one is used.
	if anHour := sdk.NowTime().Add(time.Hour); creds.Expires.After(anHour) {
		creds.Expires = anHour
	}

	return creds, nil
}

// HandleFailToRefresh handles a failed attempt to refresh credentials, given
// the previously retrieved credentials and the refresh error err:
//
//   - If prevCreds cannot expire, empty credentials and err are returned.
//   - If prevCreds do not expire within the next 5 minutes, they are returned
//     unchanged.
//   - Otherwise, prevCreds are returned with Expires extended to a random time
//     between 5 and 15 minutes from now, and a warning is logged. If no random
//     value can be generated, an error is returned instead.
func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
	aws.Credentials, error,
) {
	if !prevCreds.CanExpire {
		return aws.Credentials{}, err
	}

	// Credentials still valid for more than 5 minutes can keep being used as-is.
	if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
		return prevCreds, nil
	}

	// Use a distinct name so the refresh error passed in as err is not
	// silently overwritten by the random number generator's error.
	randFloat64, randErr := sdkrand.CryptoRandFloat64()
	if randErr != nil {
		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", randErr)
	}

	// Random distribution of [5,15) minutes.
	expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute

	newCreds := prevCreds
	newCreds.Expires = sdk.NowTime().Add(expireOffset)

	logger := middleware.GetLogger(ctx)
	logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))

	return newCreds, nil
}

// AdjustExpiresBy adds the passed in duration to the passed in credential's
// Expires time. The credentials are left unchanged if they cannot expire or
// if the time until Expires is less than 15 minutes. The credentials are
// always returned, even if not updated, and the returned error is always nil.
func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
	aws.Credentials, error,
) {
	if !creds.CanExpire {
		return creds, nil
	}
	if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
		return creds, nil
	}

	creds.Expires = creds.Expires.Add(dur)
	return creds, nil
}

// ec2RoleCredRespBody provides the shape for unmarshaling credential
// request responses.
163 type ec2RoleCredRespBody struct { 164 // Success State 165 Expiration time.Time 166 AccessKeyID string 167 SecretAccessKey string 168 Token string 169 170 // Error state 171 Code string 172 Message string 173 } 174 175 const iamSecurityCredsPath = "/iam/security-credentials/" 176 177 // requestCredList requests a list of credentials from the EC2 service. If 178 // there are no credentials, or there is an error making or receiving the 179 // request 180 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) { 181 resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{ 182 Path: iamSecurityCredsPath, 183 }) 184 if err != nil { 185 return nil, fmt.Errorf("no EC2 IMDS role found, %w", err) 186 } 187 defer resp.Content.Close() 188 189 credsList := []string{} 190 s := bufio.NewScanner(resp.Content) 191 for s.Scan() { 192 credsList = append(credsList, s.Text()) 193 } 194 195 if err := s.Err(); err != nil { 196 return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err) 197 } 198 199 return credsList, nil 200 } 201 202 // requestCred requests the credentials for a specific credentials from the EC2 service. 203 // 204 // If the credentials cannot be found, or there is an error reading the response 205 // and error will be returned. 
func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
	// An empty role name would join to the role list path itself, whose
	// plain-text body can never decode as credentials; fail early and clearly.
	if strings.TrimSpace(credsName) == "" {
		return ec2RoleCredRespBody{},
			fmt.Errorf("failed to get EC2 IMDS role credentials, empty role name")
	}

	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
		Path: path.Join(iamSecurityCredsPath, credsName),
	})
	if err != nil {
		return ec2RoleCredRespBody{},
			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
				credsName, err)
	}
	defer resp.Content.Close()

	var respCreds ec2RoleCredRespBody
	if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
		return ec2RoleCredRespBody{},
			fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
				credsName, err)
	}

	if !strings.EqualFold(respCreds.Code, "Success") {
		// If an error code was returned something failed requesting the role.
		return ec2RoleCredRespBody{},
			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
				credsName,
				&smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
	}

	return respCreds, nil
}

// ProviderSources returns the credential chain that was used to construct this
// provider. If Options.CredentialSources was not set, the provider is assumed
// to be used directly, and only aws.CredentialSourceIMDS is reported.
func (p *Provider) ProviderSources() []aws.CredentialSource {
	if p.options.CredentialSources == nil {
		return []aws.CredentialSource{aws.CredentialSourceIMDS}
	}
	return p.options.CredentialSources
}