src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

provider.go (6875B)


      1 package ec2rolecreds
      2 
      3 import (
      4 	"bufio"
      5 	"context"
      6 	"encoding/json"
      7 	"fmt"
      8 	"math"
      9 	"path"
     10 	"strings"
     11 	"time"
     12 
     13 	"github.com/aws/aws-sdk-go-v2/aws"
     14 	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
     15 	sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
     16 	"github.com/aws/aws-sdk-go-v2/internal/sdk"
     17 	"github.com/aws/smithy-go"
     18 	"github.com/aws/smithy-go/logging"
     19 	"github.com/aws/smithy-go/middleware"
     20 )
     21 
     22 // ProviderName provides a name of EC2Role provider
     23 const ProviderName = "EC2RoleProvider"
     24 
     25 // GetMetadataAPIClient provides the interface for an EC2 IMDS API client for the
     26 // GetMetadata operation.
     27 type GetMetadataAPIClient interface {
     28 	GetMetadata(context.Context, *imds.GetMetadataInput, ...func(*imds.Options)) (*imds.GetMetadataOutput, error)
     29 }
     30 
     31 // A Provider retrieves credentials from the EC2 service, and keeps track if
     32 // those credentials are expired.
     33 //
     34 // The New function must be used to create the with a custom EC2 IMDS client.
     35 //
     36 //	p := &ec2rolecreds.New(func(o *ec2rolecreds.Options{
     37 //	     o.Client = imds.New(imds.Options{/* custom options */})
     38 //	})
     39 type Provider struct {
     40 	options Options
     41 }
     42 
     43 // Options is a list of user settable options for setting the behavior of the Provider.
     44 type Options struct {
     45 	// The API client that will be used by the provider to make GetMetadata API
     46 	// calls to EC2 IMDS.
     47 	//
     48 	// If nil, the provider will default to the EC2 IMDS client.
     49 	Client GetMetadataAPIClient
     50 }
     51 
     52 // New returns an initialized Provider value configured to retrieve
     53 // credentials from EC2 Instance Metadata service.
     54 func New(optFns ...func(*Options)) *Provider {
     55 	options := Options{}
     56 
     57 	for _, fn := range optFns {
     58 		fn(&options)
     59 	}
     60 
     61 	if options.Client == nil {
     62 		options.Client = imds.New(imds.Options{})
     63 	}
     64 
     65 	return &Provider{
     66 		options: options,
     67 	}
     68 }
     69 
     70 // Retrieve retrieves credentials from the EC2 service. Error will be returned
     71 // if the request fails, or unable to extract the desired credentials.
     72 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
     73 	credsList, err := requestCredList(ctx, p.options.Client)
     74 	if err != nil {
     75 		return aws.Credentials{Source: ProviderName}, err
     76 	}
     77 
     78 	if len(credsList) == 0 {
     79 		return aws.Credentials{Source: ProviderName},
     80 			fmt.Errorf("unexpected empty EC2 IMDS role list")
     81 	}
     82 	credsName := credsList[0]
     83 
     84 	roleCreds, err := requestCred(ctx, p.options.Client, credsName)
     85 	if err != nil {
     86 		return aws.Credentials{Source: ProviderName}, err
     87 	}
     88 
     89 	creds := aws.Credentials{
     90 		AccessKeyID:     roleCreds.AccessKeyID,
     91 		SecretAccessKey: roleCreds.SecretAccessKey,
     92 		SessionToken:    roleCreds.Token,
     93 		Source:          ProviderName,
     94 
     95 		CanExpire: true,
     96 		Expires:   roleCreds.Expiration,
     97 	}
     98 
     99 	// Cap role credentials Expires to 1 hour so they can be refreshed more
    100 	// often. Jitter will be applied credentials cache if being used.
    101 	if anHour := sdk.NowTime().Add(1 * time.Hour); creds.Expires.After(anHour) {
    102 		creds.Expires = anHour
    103 	}
    104 
    105 	return creds, nil
    106 }
    107 
    108 // HandleFailToRefresh will extend the credentials Expires time if it it is
    109 // expired. If the credentials will not expire within the minimum time, they
    110 // will be returned.
    111 //
    112 // If the credentials cannot expire, the original error will be returned.
    113 func (p *Provider) HandleFailToRefresh(ctx context.Context, prevCreds aws.Credentials, err error) (
    114 	aws.Credentials, error,
    115 ) {
    116 	if !prevCreds.CanExpire {
    117 		return aws.Credentials{}, err
    118 	}
    119 
    120 	if prevCreds.Expires.After(sdk.NowTime().Add(5 * time.Minute)) {
    121 		return prevCreds, nil
    122 	}
    123 
    124 	newCreds := prevCreds
    125 	randFloat64, err := sdkrand.CryptoRandFloat64()
    126 	if err != nil {
    127 		return aws.Credentials{}, fmt.Errorf("failed to get random float, %w", err)
    128 	}
    129 
    130 	// Random distribution of [5,15) minutes.
    131 	expireOffset := time.Duration(randFloat64*float64(10*time.Minute)) + 5*time.Minute
    132 	newCreds.Expires = sdk.NowTime().Add(expireOffset)
    133 
    134 	logger := middleware.GetLogger(ctx)
    135 	logger.Logf(logging.Warn, "Attempting credential expiration extension due to a credential service availability issue. A refresh of these credentials will be attempted again in %v minutes.", math.Floor(expireOffset.Minutes()))
    136 
    137 	return newCreds, nil
    138 }
    139 
    140 // AdjustExpiresBy will adds the passed in duration to the passed in
    141 // credential's Expires time, unless the time until Expires is less than 15
    142 // minutes. Returns the credentials, even if not updated.
    143 func (p *Provider) AdjustExpiresBy(creds aws.Credentials, dur time.Duration) (
    144 	aws.Credentials, error,
    145 ) {
    146 	if !creds.CanExpire {
    147 		return creds, nil
    148 	}
    149 	if creds.Expires.Before(sdk.NowTime().Add(15 * time.Minute)) {
    150 		return creds, nil
    151 	}
    152 
    153 	creds.Expires = creds.Expires.Add(dur)
    154 	return creds, nil
    155 }
    156 
    157 // ec2RoleCredRespBody provides the shape for unmarshaling credential
    158 // request responses.
    159 type ec2RoleCredRespBody struct {
    160 	// Success State
    161 	Expiration      time.Time
    162 	AccessKeyID     string
    163 	SecretAccessKey string
    164 	Token           string
    165 
    166 	// Error state
    167 	Code    string
    168 	Message string
    169 }
    170 
    171 const iamSecurityCredsPath = "/iam/security-credentials/"
    172 
    173 // requestCredList requests a list of credentials from the EC2 service. If
    174 // there are no credentials, or there is an error making or receiving the
    175 // request
    176 func requestCredList(ctx context.Context, client GetMetadataAPIClient) ([]string, error) {
    177 	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
    178 		Path: iamSecurityCredsPath,
    179 	})
    180 	if err != nil {
    181 		return nil, fmt.Errorf("no EC2 IMDS role found, %w", err)
    182 	}
    183 	defer resp.Content.Close()
    184 
    185 	credsList := []string{}
    186 	s := bufio.NewScanner(resp.Content)
    187 	for s.Scan() {
    188 		credsList = append(credsList, s.Text())
    189 	}
    190 
    191 	if err := s.Err(); err != nil {
    192 		return nil, fmt.Errorf("failed to read EC2 IMDS role, %w", err)
    193 	}
    194 
    195 	return credsList, nil
    196 }
    197 
    198 // requestCred requests the credentials for a specific credentials from the EC2 service.
    199 //
    200 // If the credentials cannot be found, or there is an error reading the response
    201 // and error will be returned.
    202 func requestCred(ctx context.Context, client GetMetadataAPIClient, credsName string) (ec2RoleCredRespBody, error) {
    203 	resp, err := client.GetMetadata(ctx, &imds.GetMetadataInput{
    204 		Path: path.Join(iamSecurityCredsPath, credsName),
    205 	})
    206 	if err != nil {
    207 		return ec2RoleCredRespBody{},
    208 			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
    209 				credsName, err)
    210 	}
    211 	defer resp.Content.Close()
    212 
    213 	var respCreds ec2RoleCredRespBody
    214 	if err := json.NewDecoder(resp.Content).Decode(&respCreds); err != nil {
    215 		return ec2RoleCredRespBody{},
    216 			fmt.Errorf("failed to decode %s EC2 IMDS role credentials, %w",
    217 				credsName, err)
    218 	}
    219 
    220 	if !strings.EqualFold(respCreds.Code, "Success") {
    221 		// If an error code was returned something failed requesting the role.
    222 		return ec2RoleCredRespBody{},
    223 			fmt.Errorf("failed to get %s EC2 IMDS role credentials, %w",
    224 				credsName,
    225 				&smithy.GenericAPIError{Code: respCreds.Code, Message: respCreds.Message})
    226 	}
    227 
    228 	return respCreds, nil
    229 }