sso_cached_token.go (5849B)
1 package ssocreds 2 3 import ( 4 "crypto/sha1" 5 "encoding/hex" 6 "encoding/json" 7 "fmt" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "strconv" 12 "strings" 13 "time" 14 15 "github.com/aws/aws-sdk-go-v2/internal/sdk" 16 "github.com/aws/aws-sdk-go-v2/internal/shareddefaults" 17 ) 18 19 var osUserHomeDur = shareddefaults.UserHomeDir 20 21 // StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or 22 // error if unable get derive the path. Key that will be used to compute a SHA1 23 // value that is hex encoded. 24 // 25 // Derives the filepath using the Key as: 26 // 27 // ~/.aws/sso/cache/<sha1-hex-encoded-key>.json 28 func StandardCachedTokenFilepath(key string) (string, error) { 29 homeDir := osUserHomeDur() 30 if len(homeDir) == 0 { 31 return "", fmt.Errorf("unable to get USER's home directory for cached token") 32 } 33 hash := sha1.New() 34 if _, err := hash.Write([]byte(key)); err != nil { 35 return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err) 36 } 37 38 cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json" 39 40 return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil 41 } 42 43 type tokenKnownFields struct { 44 AccessToken string `json:"accessToken,omitempty"` 45 ExpiresAt *rfc3339 `json:"expiresAt,omitempty"` 46 47 RefreshToken string `json:"refreshToken,omitempty"` 48 ClientID string `json:"clientId,omitempty"` 49 ClientSecret string `json:"clientSecret,omitempty"` 50 } 51 52 type token struct { 53 tokenKnownFields 54 UnknownFields map[string]interface{} `json:"-"` 55 } 56 57 func (t token) MarshalJSON() ([]byte, error) { 58 fields := map[string]interface{}{} 59 60 setTokenFieldString(fields, "accessToken", t.AccessToken) 61 setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt) 62 63 setTokenFieldString(fields, "refreshToken", t.RefreshToken) 64 setTokenFieldString(fields, "clientId", t.ClientID) 65 setTokenFieldString(fields, "clientSecret", t.ClientSecret) 66 67 for k, v := range t.UnknownFields { 68 if _, ok := fields[k]; ok { 69 return nil, fmt.Errorf("unknown token field %v, duplicates known field", k) 70 } 71 fields[k] = v 72 } 73 74 return json.Marshal(fields) 75 } 76 77 func setTokenFieldString(fields map[string]interface{}, key, value string) { 78 if value == "" { 79 return 80 } 81 fields[key] = value 82 } 83 func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) { 84 if value == nil { 85 return 86 } 87 fields[key] = value 88 } 89 90 func (t *token) UnmarshalJSON(b []byte) error { 91 var fields map[string]interface{} 92 if err := json.Unmarshal(b, &fields); err != nil { 93 return nil 94 } 95 96 t.UnknownFields = map[string]interface{}{} 97 98 for k, v := range fields { 99 var err error 100 switch k { 101 case "accessToken": 102 err = getTokenFieldString(v, &t.AccessToken) 103 case "expiresAt": 104 err = getTokenFieldRFC3339(v, &t.ExpiresAt) 105 case "refreshToken": 106 err = getTokenFieldString(v, &t.RefreshToken) 107 case "clientId": 108 err = getTokenFieldString(v, &t.ClientID) 109 case "clientSecret": 110 err = getTokenFieldString(v, &t.ClientSecret) 111 default: 112 t.UnknownFields[k] = v 113 } 114 115 if err != nil { 116 return fmt.Errorf("field %q, %w", k, err) 117 } 118 } 119 120 return nil 121 } 122 123 func getTokenFieldString(v interface{}, value *string) error { 124 var ok bool 125 *value, ok = v.(string) 126 if !ok { 127 return fmt.Errorf("expect value to be string, got %T", v) 128 } 129 return nil 130 } 131 132 func getTokenFieldRFC3339(v interface{}, value **rfc3339) error { 133 var stringValue string 134 if err := getTokenFieldString(v, &stringValue); err != nil { 135 return err 136 } 137 138 timeValue, err := parseRFC3339(stringValue) 139 if err != nil { 140 return err 141 } 142 143 *value = &timeValue 144 return nil 145 } 146 147 func loadCachedToken(filename string) (token, error) { 148 fileBytes, err := ioutil.ReadFile(filename) 149 if err != nil { 150 return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err) 151 } 152 153 var t token 154 if err := json.Unmarshal(fileBytes, &t); err != nil { 155 return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err) 156 } 157 158 if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() { 159 return token{}, fmt.Errorf( 160 "cached SSO token must contain accessToken and expiresAt fields") 161 } 162 163 return t, nil 164 } 165 166 func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) { 167 tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10) 168 if err := writeCacheFile(tmpFilename, fileMode, t); err != nil { 169 return err 170 } 171 172 if err := os.Rename(tmpFilename, filename); err != nil { 173 return fmt.Errorf("failed to replace old cached SSO token file, %w", err) 174 } 175 176 return nil 177 } 178 179 func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) { 180 var f *os.File 181 f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode) 182 if err != nil { 183 return fmt.Errorf("failed to create cached SSO token file %w", err) 184 } 185 186 defer func() { 187 closeErr := f.Close() 188 if err == nil && closeErr != nil { 189 err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr) 190 } 191 }() 192 193 encoder := json.NewEncoder(f) 194 195 if err = encoder.Encode(t); err != nil { 196 return fmt.Errorf("failed to serialize cached SSO token, %w", err) 197 } 198 199 return nil 200 } 201 202 type rfc3339 time.Time 203 204 func parseRFC3339(v string) (rfc3339, error) { 205 parsed, err := time.Parse(time.RFC3339, v) 206 if err != nil { 207 return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err) 208 } 209 210 return rfc3339(parsed), nil 211 } 212 213 func (r *rfc3339) UnmarshalJSON(bytes []byte) (err error) { 214 var value string 215 216 // Use JSON unmarshal to unescape the quoted value making use of JSON's 217 // unquoting rules. 218 if err = json.Unmarshal(bytes, &value); err != nil { 219 return err 220 } 221 222 *r, err = parseRFC3339(value) 223 224 return nil 225 } 226 227 func (r *rfc3339) MarshalJSON() ([]byte, error) { 228 value := time.Time(*r).Format(time.RFC3339) 229 230 // Use JSON unmarshal to unescape the quoted value making use of JSON's 231 // quoting rules. 232 return json.Marshal(value) 233 }