src

Go monorepo.
git clone git://code.dwrz.net/src
Log | Files | Refs

provider.go (7911B)


      1 package processcreds
      2 
      3 import (
      4 	"bytes"
      5 	"context"
      6 	"encoding/json"
      7 	"fmt"
      8 	"io"
      9 	"os"
     10 	"os/exec"
     11 	"runtime"
     12 	"time"
     13 
     14 	"github.com/aws/aws-sdk-go-v2/aws"
     15 	"github.com/aws/aws-sdk-go-v2/internal/sdkio"
     16 )
     17 
     18 const (
     19 	// ProviderName is the name this credentials provider will label any
     20 	// returned credentials Value with.
     21 	ProviderName = `ProcessProvider`
     22 
     23 	// DefaultTimeout default limit on time a process can run.
     24 	DefaultTimeout = time.Duration(1) * time.Minute
     25 )
     26 
     27 // ProviderError is an error indicating failure initializing or executing the
     28 // process credentials provider
     29 type ProviderError struct {
     30 	Err error
     31 }
     32 
     33 // Error returns the error message.
     34 func (e *ProviderError) Error() string {
     35 	return fmt.Sprintf("process provider error: %v", e.Err)
     36 }
     37 
     38 // Unwrap returns the underlying error the provider error wraps.
     39 func (e *ProviderError) Unwrap() error {
     40 	return e.Err
     41 }
     42 
     43 // Provider satisfies the credentials.Provider interface, and is a
     44 // client to retrieve credentials from a process.
     45 type Provider struct {
     46 	// Provides a constructor for exec.Cmd that are invoked by the provider for
     47 	// retrieving credentials. Use this to provide custom creation of exec.Cmd
     48 	// with things like environment variables, or other configuration.
     49 	//
     50 	// The provider defaults to the DefaultNewCommand function.
     51 	commandBuilder NewCommandBuilder
     52 
     53 	options Options
     54 }
     55 
     56 // Options is the configuration options for configuring the Provider.
     57 type Options struct {
     58 	// Timeout limits the time a process can run.
     59 	Timeout time.Duration
     60 }
     61 
     62 // NewCommandBuilder provides the interface for specifying how command will be
     63 // created that the Provider will use to retrieve credentials with.
     64 type NewCommandBuilder interface {
     65 	NewCommand(context.Context) (*exec.Cmd, error)
     66 }
     67 
     68 // NewCommandBuilderFunc provides a wrapper type around a function pointer to
     69 // satisfy the NewCommandBuilder interface.
     70 type NewCommandBuilderFunc func(context.Context) (*exec.Cmd, error)
     71 
     72 // NewCommand calls the underlying function pointer the builder was initialized with.
     73 func (fn NewCommandBuilderFunc) NewCommand(ctx context.Context) (*exec.Cmd, error) {
     74 	return fn(ctx)
     75 }
     76 
     77 // DefaultNewCommandBuilder provides the default NewCommandBuilder
     78 // implementation used by the provider. It takes a command and arguments to
     79 // invoke. The command will also be initialized with the current process
     80 // environment variables, stderr, and stdin pipes.
     81 type DefaultNewCommandBuilder struct {
     82 	Args []string
     83 }
     84 
     85 // NewCommand returns an initialized exec.Cmd with the builder's initialized
     86 // Args. The command is also initialized current process environment variables,
     87 // stderr, and stdin pipes.
     88 func (b DefaultNewCommandBuilder) NewCommand(ctx context.Context) (*exec.Cmd, error) {
     89 	var cmdArgs []string
     90 	if runtime.GOOS == "windows" {
     91 		cmdArgs = []string{"cmd.exe", "/C"}
     92 	} else {
     93 		cmdArgs = []string{"sh", "-c"}
     94 	}
     95 
     96 	if len(b.Args) == 0 {
     97 		return nil, &ProviderError{
     98 			Err: fmt.Errorf("failed to prepare command: command must not be empty"),
     99 		}
    100 	}
    101 
    102 	cmdArgs = append(cmdArgs, b.Args...)
    103 	cmd := exec.CommandContext(ctx, cmdArgs[0], cmdArgs[1:]...)
    104 	cmd.Env = os.Environ()
    105 
    106 	cmd.Stderr = os.Stderr // display stderr on console for MFA
    107 	cmd.Stdin = os.Stdin   // enable stdin for MFA
    108 
    109 	return cmd, nil
    110 }
    111 
    112 // NewProvider returns a pointer to a new Credentials object wrapping the
    113 // Provider.
    114 //
    115 // The provider defaults to the DefaultNewCommandBuilder for creating command
    116 // the Provider will use to retrieve credentials with.
    117 func NewProvider(command string, options ...func(*Options)) *Provider {
    118 	var args []string
    119 
    120 	// Ensure that the command arguments are not set if the provided command is
    121 	// empty. This will error out when the command is executed since no
    122 	// arguments are specified.
    123 	if len(command) > 0 {
    124 		args = []string{command}
    125 	}
    126 
    127 	commanBuilder := DefaultNewCommandBuilder{
    128 		Args: args,
    129 	}
    130 	return NewProviderCommand(commanBuilder, options...)
    131 }
    132 
    133 // NewProviderCommand returns a pointer to a new Credentials object with the
    134 // specified command, and default timeout duration. Use this to provide custom
    135 // creation of exec.Cmd for options like environment variables, or other
    136 // configuration.
    137 func NewProviderCommand(builder NewCommandBuilder, options ...func(*Options)) *Provider {
    138 	p := &Provider{
    139 		commandBuilder: builder,
    140 		options: Options{
    141 			Timeout: DefaultTimeout,
    142 		},
    143 	}
    144 
    145 	for _, option := range options {
    146 		option(&p.options)
    147 	}
    148 
    149 	return p
    150 }
    151 
    152 // A CredentialProcessResponse is the AWS credentials format that must be
    153 // returned when executing an external credential_process.
    154 type CredentialProcessResponse struct {
    155 	// As of this writing, the Version key must be set to 1. This might
    156 	// increment over time as the structure evolves.
    157 	Version int
    158 
    159 	// The access key ID that identifies the temporary security credentials.
    160 	AccessKeyID string `json:"AccessKeyId"`
    161 
    162 	// The secret access key that can be used to sign requests.
    163 	SecretAccessKey string
    164 
    165 	// The token that users must pass to the service API to use the temporary credentials.
    166 	SessionToken string
    167 
    168 	// The date on which the current credentials expire.
    169 	Expiration *time.Time
    170 }
    171 
    172 // Retrieve executes the credential process command and returns the
    173 // credentials, or error if the command fails.
    174 func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
    175 	out, err := p.executeCredentialProcess(ctx)
    176 	if err != nil {
    177 		return aws.Credentials{Source: ProviderName}, err
    178 	}
    179 
    180 	// Serialize and validate response
    181 	resp := &CredentialProcessResponse{}
    182 	if err = json.Unmarshal(out, resp); err != nil {
    183 		return aws.Credentials{Source: ProviderName}, &ProviderError{
    184 			Err: fmt.Errorf("parse failed of process output: %s, error: %w", out, err),
    185 		}
    186 	}
    187 
    188 	if resp.Version != 1 {
    189 		return aws.Credentials{Source: ProviderName}, &ProviderError{
    190 			Err: fmt.Errorf("wrong version in process output (not 1)"),
    191 		}
    192 	}
    193 
    194 	if len(resp.AccessKeyID) == 0 {
    195 		return aws.Credentials{Source: ProviderName}, &ProviderError{
    196 			Err: fmt.Errorf("missing AccessKeyId in process output"),
    197 		}
    198 	}
    199 
    200 	if len(resp.SecretAccessKey) == 0 {
    201 		return aws.Credentials{Source: ProviderName}, &ProviderError{
    202 			Err: fmt.Errorf("missing SecretAccessKey in process output"),
    203 		}
    204 	}
    205 
    206 	creds := aws.Credentials{
    207 		Source:          ProviderName,
    208 		AccessKeyID:     resp.AccessKeyID,
    209 		SecretAccessKey: resp.SecretAccessKey,
    210 		SessionToken:    resp.SessionToken,
    211 	}
    212 
    213 	// Handle expiration
    214 	if resp.Expiration != nil {
    215 		creds.CanExpire = true
    216 		creds.Expires = *resp.Expiration
    217 	}
    218 
    219 	return creds, nil
    220 }
    221 
    222 // executeCredentialProcess starts the credential process on the OS and
    223 // returns the results or an error.
    224 func (p *Provider) executeCredentialProcess(ctx context.Context) ([]byte, error) {
    225 	if p.options.Timeout >= 0 {
    226 		var cancelFunc func()
    227 		ctx, cancelFunc = context.WithTimeout(ctx, p.options.Timeout)
    228 		defer cancelFunc()
    229 	}
    230 
    231 	cmd, err := p.commandBuilder.NewCommand(ctx)
    232 	if err != nil {
    233 		return nil, err
    234 	}
    235 
    236 	// get creds json on process's stdout
    237 	output := bytes.NewBuffer(make([]byte, 0, int(8*sdkio.KibiByte)))
    238 	if cmd.Stdout != nil {
    239 		cmd.Stdout = io.MultiWriter(cmd.Stdout, output)
    240 	} else {
    241 		cmd.Stdout = output
    242 	}
    243 
    244 	execCh := make(chan error, 1)
    245 	go executeCommand(cmd, execCh)
    246 
    247 	select {
    248 	case execError := <-execCh:
    249 		if execError == nil {
    250 			break
    251 		}
    252 		select {
    253 		case <-ctx.Done():
    254 			return output.Bytes(), &ProviderError{
    255 				Err: fmt.Errorf("credential process timed out: %w", execError),
    256 			}
    257 		default:
    258 			return output.Bytes(), &ProviderError{
    259 				Err: fmt.Errorf("error in credential_process: %w", execError),
    260 			}
    261 		}
    262 	}
    263 
    264 	out := output.Bytes()
    265 	if runtime.GOOS == "windows" {
    266 		// windows adds slashes to quotes
    267 		out = bytes.ReplaceAll(out, []byte(`\"`), []byte(`"`))
    268 	}
    269 
    270 	return out, nil
    271 }
    272 
    273 func executeCommand(cmd *exec.Cmd, exec chan error) {
    274 	// Start the command
    275 	err := cmd.Start()
    276 	if err == nil {
    277 		err = cmd.Wait()
    278 	}
    279 
    280 	exec <- err
    281 }