src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

sso_cached_token.go (5849B)


      1 package ssocreds
      2 
      3 import (
      4 	"crypto/sha1"
      5 	"encoding/hex"
      6 	"encoding/json"
      7 	"fmt"
      8 	"io/ioutil"
      9 	"os"
     10 	"path/filepath"
     11 	"strconv"
     12 	"strings"
     13 	"time"
     14 
     15 	"github.com/aws/aws-sdk-go-v2/internal/sdk"
     16 	"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
     17 )
     18 
     19 var osUserHomeDur = shareddefaults.UserHomeDir
     20 
     21 // StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
     22 // error if unable get derive the path. Key that will be used to compute a SHA1
     23 // value that is hex encoded.
     24 //
     25 // Derives the filepath using the Key as:
     26 //
     27 //	~/.aws/sso/cache/<sha1-hex-encoded-key>.json
     28 func StandardCachedTokenFilepath(key string) (string, error) {
     29 	homeDir := osUserHomeDur()
     30 	if len(homeDir) == 0 {
     31 		return "", fmt.Errorf("unable to get USER's home directory for cached token")
     32 	}
     33 	hash := sha1.New()
     34 	if _, err := hash.Write([]byte(key)); err != nil {
     35 		return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
     36 	}
     37 
     38 	cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"
     39 
     40 	return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
     41 }
     42 
     43 type tokenKnownFields struct {
     44 	AccessToken string   `json:"accessToken,omitempty"`
     45 	ExpiresAt   *rfc3339 `json:"expiresAt,omitempty"`
     46 
     47 	RefreshToken string `json:"refreshToken,omitempty"`
     48 	ClientID     string `json:"clientId,omitempty"`
     49 	ClientSecret string `json:"clientSecret,omitempty"`
     50 }
     51 
     52 type token struct {
     53 	tokenKnownFields
     54 	UnknownFields map[string]interface{} `json:"-"`
     55 }
     56 
     57 func (t token) MarshalJSON() ([]byte, error) {
     58 	fields := map[string]interface{}{}
     59 
     60 	setTokenFieldString(fields, "accessToken", t.AccessToken)
     61 	setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)
     62 
     63 	setTokenFieldString(fields, "refreshToken", t.RefreshToken)
     64 	setTokenFieldString(fields, "clientId", t.ClientID)
     65 	setTokenFieldString(fields, "clientSecret", t.ClientSecret)
     66 
     67 	for k, v := range t.UnknownFields {
     68 		if _, ok := fields[k]; ok {
     69 			return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
     70 		}
     71 		fields[k] = v
     72 	}
     73 
     74 	return json.Marshal(fields)
     75 }
     76 
     77 func setTokenFieldString(fields map[string]interface{}, key, value string) {
     78 	if value == "" {
     79 		return
     80 	}
     81 	fields[key] = value
     82 }
     83 func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
     84 	if value == nil {
     85 		return
     86 	}
     87 	fields[key] = value
     88 }
     89 
     90 func (t *token) UnmarshalJSON(b []byte) error {
     91 	var fields map[string]interface{}
     92 	if err := json.Unmarshal(b, &fields); err != nil {
     93 		return nil
     94 	}
     95 
     96 	t.UnknownFields = map[string]interface{}{}
     97 
     98 	for k, v := range fields {
     99 		var err error
    100 		switch k {
    101 		case "accessToken":
    102 			err = getTokenFieldString(v, &t.AccessToken)
    103 		case "expiresAt":
    104 			err = getTokenFieldRFC3339(v, &t.ExpiresAt)
    105 		case "refreshToken":
    106 			err = getTokenFieldString(v, &t.RefreshToken)
    107 		case "clientId":
    108 			err = getTokenFieldString(v, &t.ClientID)
    109 		case "clientSecret":
    110 			err = getTokenFieldString(v, &t.ClientSecret)
    111 		default:
    112 			t.UnknownFields[k] = v
    113 		}
    114 
    115 		if err != nil {
    116 			return fmt.Errorf("field %q, %w", k, err)
    117 		}
    118 	}
    119 
    120 	return nil
    121 }
    122 
    123 func getTokenFieldString(v interface{}, value *string) error {
    124 	var ok bool
    125 	*value, ok = v.(string)
    126 	if !ok {
    127 		return fmt.Errorf("expect value to be string, got %T", v)
    128 	}
    129 	return nil
    130 }
    131 
    132 func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
    133 	var stringValue string
    134 	if err := getTokenFieldString(v, &stringValue); err != nil {
    135 		return err
    136 	}
    137 
    138 	timeValue, err := parseRFC3339(stringValue)
    139 	if err != nil {
    140 		return err
    141 	}
    142 
    143 	*value = &timeValue
    144 	return nil
    145 }
    146 
    147 func loadCachedToken(filename string) (token, error) {
    148 	fileBytes, err := ioutil.ReadFile(filename)
    149 	if err != nil {
    150 		return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err)
    151 	}
    152 
    153 	var t token
    154 	if err := json.Unmarshal(fileBytes, &t); err != nil {
    155 		return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err)
    156 	}
    157 
    158 	if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
    159 		return token{}, fmt.Errorf(
    160 			"cached SSO token must contain accessToken and expiresAt fields")
    161 	}
    162 
    163 	return t, nil
    164 }
    165 
    166 func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) {
    167 	tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
    168 	if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
    169 		return err
    170 	}
    171 
    172 	if err := os.Rename(tmpFilename, filename); err != nil {
    173 		return fmt.Errorf("failed to replace old cached SSO token file, %w", err)
    174 	}
    175 
    176 	return nil
    177 }
    178 
    179 func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) {
    180 	var f *os.File
    181 	f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
    182 	if err != nil {
    183 		return fmt.Errorf("failed to create cached SSO token file %w", err)
    184 	}
    185 
    186 	defer func() {
    187 		closeErr := f.Close()
    188 		if err == nil && closeErr != nil {
    189 			err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr)
    190 		}
    191 	}()
    192 
    193 	encoder := json.NewEncoder(f)
    194 
    195 	if err = encoder.Encode(t); err != nil {
    196 		return fmt.Errorf("failed to serialize cached SSO token, %w", err)
    197 	}
    198 
    199 	return nil
    200 }
    201 
    202 type rfc3339 time.Time
    203 
    204 func parseRFC3339(v string) (rfc3339, error) {
    205 	parsed, err := time.Parse(time.RFC3339, v)
    206 	if err != nil {
    207 		return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err)
    208 	}
    209 
    210 	return rfc3339(parsed), nil
    211 }
    212 
    213 func (r *rfc3339) UnmarshalJSON(bytes []byte) (err error) {
    214 	var value string
    215 
    216 	// Use JSON unmarshal to unescape the quoted value making use of JSON's
    217 	// unquoting rules.
    218 	if err = json.Unmarshal(bytes, &value); err != nil {
    219 		return err
    220 	}
    221 
    222 	*r, err = parseRFC3339(value)
    223 
    224 	return nil
    225 }
    226 
    227 func (r *rfc3339) MarshalJSON() ([]byte, error) {
    228 	value := time.Time(*r).Format(time.RFC3339)
    229 
    230 	// Use JSON unmarshal to unescape the quoted value making use of JSON's
    231 	// quoting rules.
    232 	return json.Marshal(value)
    233 }