src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

provider.go (7455B)


      1 package ec2rolecreds
      2 
      3 import (
      4 	"bufio"
      5 	"context"
      6 	"encoding/json"
      7 	"fmt"
      8 	"math"
      9 	"path"
     10 	"strings"
     11 	"time"
     12 
     13 	"github.com/aws/aws-sdk-go-v2/aws"
     14 	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
     15 	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
     16 	"github.com/aws/aws-sdk-go-v2/internal/sdk"
     17 	"github.com/aws/smithy-go"
     18 	"github.com/aws/smithy-go/logging"
     19 	"github.com/aws/smithy-go/middleware"
     20 )
     21 
     22 // ProviderName provides a name of EC2Role provider
     23 const ProviderName = "EC2RoleProvider"
     24 
     25 // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
     26 // GetMetadata operation.
     27 type GetMetadataAPIClient interface {
     28 	GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
     29 }
     30 
     31 // A Provider retrieves credentials from the EC2 service, and keeps track if
     32 // those credentials are expired.
     33 //
     34 // The New function must be used to create the with a custom EC2 IMDS client.
     35 //
     36 //	p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
     37 //	     o.Client = imds.New(imds.Options{/* custom options */})
     38 //	})
     39 type Provider struct {
     40 	options Options
     41 }
     42 
     43 // Options is a list of user settable options for setting the behavior of the Provider.
     44 type Options struct {
     45 	// The API client that will be used by the provider to make GetMetadata API
     46 	// calls to EC2 IMDS.
     47 	//
     48 	// If nil, the provider will default to the EC2 IMDS client.
     49 	Client GetMetadataAPIClient
     50 
     51 	// The chain of providers that was used to create this provider
     52 	// These values are for reporting purposes and are not meant to be set up directly
     53 	CredentialSources []aws.CredentialSource
     54 }
     55 
     56 // New returns an initialized Provider value configured to retrieve
     57 // credentials from EC2 Instance Metadata service.
     58 func New(optFns ...func(*Options)) *Provider {
     59 	options := Options{}
     60 
     61 	for _, fn := range optFns {
     62 		fn(&options)
     63 	}
     64 
     65 	if options.Client == nil {
     66 		options.Client = imds.New(imds.Options{})
     67 	}
     68 
     69 	return &Provider{
     70 		options: options,
     71 	}
     72 }
     73 
     74 // Retrieve retrieves credentials from the EC2 service. Error will be returned
     75 // if the request fails, or unable to extract the desired credentials.
     76 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
     77 	credsList, err := requestCredList(ctx, p.options.Client)
     78 	if err != nil {
     79 		return aws.Credentials{Source: ProviderName}, err
     80 	}
     81 
     82 	if len(credsList) == 0 {
     83 		return aws.Credentials{Source: ProviderName},
     84 			fmt.Errorf("unexpected empty EC2 IMDS role list")
     85 	}
     86 	credsName := credsList[0]
     87 
     88 	roleCreds, err := requestCred(ctx, p.options.Client, credsName)
     89 	if err != nil {
     90 		return aws.Credentials{Source: ProviderName}, err
     91 	}
     92 
     93 	creds := aws.Credentials{
     94 		AccessKeyID:     roleCreds.AccessKeyID,
     95 		SecretAccessKey: roleCreds.SecretAccessKey,
     96 		SessionToken:    roleCreds.Token,
     97 		Source:          ProviderName,
     98 
     99 		CanExpire: true,
    100 		Expires:   roleCreds.Expiration,
    101 	}
    102 
    103 	// Cap role credentials Expires to 1 hour so they can be refreshed more
    104 	// often. Jitter will be applied credentials cache if being used.
    105 	if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
    106 		creds.Expires = anHour
    107 	}
    108 
    109 	return creds, nil
    110 }
    111 
    112 // HandleFailToRefresh will extend the credentials Expires time if it it is
    113 // expired. If the credentials will not expire within the minimum time, they
    114 // will be returned.
    115 //
    116 // If the credentials cannot expire, the original error will be returned.
    117 func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
    118 	aws.Credentials, error,
    119 ) {
    120 	if !prevCreds.CanExpire {
    121 		return aws.Credentials{}, err
    122 	}
    123 
    124 	if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
    125 		return prevCreds, nil
    126 	}
    127 
    128 	newCreds := prevCreds
    129 	randFloat64, err := sdkrand.CryptoRandFloat64()
    130 	if err != nil {
    131 		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
    132 	}
    133 
    134 	// Random distribution of [5,15) minutes.
    135 	expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
    136 	newCreds.Expires = sdk.NowTime().Add(expireOffset)
    137 
    138 	logger := middleware.GetLogger(ctx)
    139 	logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
    140 
    141 	return newCreds, nil
    142 }
    143 
    144 // AdjustExpiresBy will adds the passed in duration to the passed in
    145 // credential's Expires time, unless the time until Expires is less than 15
    146 // minutes. Returns the credentials, even if not updated.
    147 func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
    148 	aws.Credentials, error,
    149 ) {
    150 	if !creds.CanExpire {
    151 		return creds, nil
    152 	}
    153 	if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
    154 		return creds, nil
    155 	}
    156 
    157 	creds.Expires = creds.Expires.Add(dur)
    158 	return creds, nil
    159 }
    160 
    161 // ec2RoleCredRespBody provides the shape for unmarshaling credential
    162 // request responses.
    163 type ec2RoleCredRespBody struct {
    164 	// Success State
    165 	Expiration      time.Time
    166 	AccessKeyID     string
    167 	SecretAccessKey string
    168 	Token           string
    169 
    170 	// Error state
    171 	Code    string
    172 	Message string
    173 }
    174 
    175 const iamSecurityCredsPath = "/iam/security-credentials/"
    176 
    177 // requestCredList requests a list of credentials from the EC2 service. If
    178 // there are no credentials, or there is an error making or receiving the
    179 // request
    180 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
    181 	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
    182 		Path: iamSecurityCredsPath,
    183 	})
    184 	if err != nil {
    185 		return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
    186 	}
    187 	defer resp.Content.Close()
    188 
    189 	credsList := []string{}
    190 	s := bufio.NewScanner(resp.Content)
    191 	for s.Scan() {
    192 		credsList = append(credsList, s.Text())
    193 	}
    194 
    195 	if err := s.Err(); err != nil {
    196 		return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
    197 	}
    198 
    199 	return credsList, nil
    200 }
    201 
    202 // requestCred requests the credentials for a specific credentials from the EC2 service.
    203 //
    204 // If the credentials cannot be found, or there is an error reading the response
    205 // and error will be returned.
    206 func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
    207 	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
    208 		Path: path.Join(iamSecurityCredsPath, credsName),
    209 	})
    210 	if err != nil {
    211 		return ec2RoleCredRespBody{},
    212 			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
    213 				credsName, err)
    214 	}
    215 	defer resp.Content.Close()
    216 
    217 	var respCreds ec2RoleCredRespBody
    218 	if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
    219 		return ec2RoleCredRespBody{},
    220 			fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
    221 				credsName, err)
    222 	}
    223 
    224 	if !strings.EqualFold(respCreds.Code, "Success") {
    225 		// If an error code was returned something failed requesting the role.
    226 		return ec2RoleCredRespBody{},
    227 			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
    228 				credsName,
    229 				&smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
    230 	}
    231 
    232 	return respCreds, nil
    233 }
    234 
    235 // ProviderSources returns the credential chain that was used to construct this provider
    236 func (p *Provider) ProviderSources() []aws.CredentialSource {
    237 	if p.options.CredentialSources == nil {
    238 		return []aws.CredentialSource{aws.CredentialSourceIMDS}
    239 	} // If no source has been set, assume this is used directly which means just call to assume role
    240 	return p.options.CredentialSources
    241 }