Skip to content

Commit bc7f272

Browse files
ldgriswoldldgriswold
authored andcommitted
feat: add custom headers to taskrc and cli flags
1 parent c4ecff7 commit bc7f272

10 files changed

Lines changed: 781 additions & 9 deletions

File tree

executor.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type (
4141
CACert string
4242
Cert string
4343
CertKey string
44+
Headers map[string]map[string]string
4445
Watch bool
4546
Verbose bool
4647
Silent bool
@@ -329,6 +330,19 @@ func (o *certKeyOption) ApplyToExecutor(e *Executor) {
329330
e.CertKey = o.certKey
330331
}
331332

333+
// WithHeaders sets the HTTP headers to use for remote requests, keyed by host.
334+
func WithHeaders(headers map[string]map[string]string) ExecutorOption {
335+
return &headersOption{headers: headers}
336+
}
337+
338+
type headersOption struct {
339+
headers map[string]map[string]string
340+
}
341+
342+
func (o *headersOption) ApplyToExecutor(e *Executor) {
343+
e.Headers = o.headers
344+
}
345+
332346
// WithWatch tells the [Executor] to keep running in the background and watch
333347
// for changes to the fingerprint of the tasks that are run. When changes are
334348
// detected, a new task run is triggered.

internal/flags/flags.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"path/filepath"
88
"strconv"
9+
"strings"
910
"time"
1011

1112
"github.com/fatih/color"
@@ -86,7 +87,11 @@ var (
8687
CACert string
8788
Cert string
8889
CertKey string
90+
RemoteHeaders []string
8991
Interactive bool
92+
93+
// Store the config for access in WithFlags()
94+
taskrcConfig *taskrcast.TaskRC
9095
)
9196

9297
func init() {
@@ -109,6 +114,7 @@ func init() {
109114
dir = cmp.Or(dir, filepath.Dir(entrypoint))
110115

111116
config, _ := taskrc.GetConfig(dir)
117+
taskrcConfig = config
112118
experiments.ParseWithConfig(dir, config)
113119

114120
// Parse the rest of the flags
@@ -174,6 +180,7 @@ func init() {
174180
pflag.StringVar(&CACert, "cacert", getConfig(config, "REMOTE_CACERT", func() *string { return config.Remote.CACert }, ""), "Path to a custom CA certificate for HTTPS connections.")
175181
pflag.StringVar(&Cert, "cert", getConfig(config, "REMOTE_CERT", func() *string { return config.Remote.Cert }, ""), "Path to a client certificate for HTTPS connections.")
176182
pflag.StringVar(&CertKey, "cert-key", getConfig(config, "REMOTE_CERT_KEY", func() *string { return config.Remote.CertKey }, ""), "Path to a client certificate key for HTTPS connections.")
183+
pflag.StringSliceVar(&RemoteHeaders, "header", nil, "HTTP header for remote requests in format 'host:Header-Name=value' (can be repeated).")
177184
}
178185
pflag.Parse()
179186

@@ -247,6 +254,13 @@ func Validate() error {
247254
return errors.New("task: --cert and --cert-key must be provided together")
248255
}
249256

257+
// Validate header flags format
258+
if len(RemoteHeaders) > 0 {
259+
if _, err := parseHeaderFlags(RemoteHeaders); err != nil {
260+
return err
261+
}
262+
}
263+
250264
return nil
251265
}
252266

@@ -277,6 +291,37 @@ func (o *flagsOption) ApplyToExecutor(e *task.Executor) {
277291
}
278292
}
279293

294+
// Merge headers from config and CLI (CLI takes precedence)
295+
headers := make(map[string]map[string]string)
296+
297+
// Start with config headers
298+
if taskrcConfig != nil && taskrcConfig.Remote.Headers != nil {
299+
for host, hostHeaders := range taskrcConfig.Remote.Headers {
300+
headers[host] = make(map[string]string)
301+
for key, value := range hostHeaders {
302+
headers[host][key] = value
303+
}
304+
}
305+
}
306+
307+
// Parse and merge CLI headers (these override config)
308+
if len(RemoteHeaders) > 0 {
309+
cliHeaders, err := parseHeaderFlags(RemoteHeaders)
310+
if err != nil {
311+
// This should have been caught in Validate(), but handle it anyway
312+
log.Printf("Warning: failed to parse --header flags: %v\n", err)
313+
} else {
314+
for host, hostHeaders := range cliHeaders {
315+
if headers[host] == nil {
316+
headers[host] = make(map[string]string)
317+
}
318+
for key, value := range hostHeaders {
319+
headers[host][key] = value
320+
}
321+
}
322+
}
323+
}
324+
280325
e.Options(
281326
task.WithDir(dir),
282327
task.WithEntrypoint(Entrypoint),
@@ -292,6 +337,7 @@ func (o *flagsOption) ApplyToExecutor(e *task.Executor) {
292337
task.WithCACert(CACert),
293338
task.WithCert(Cert),
294339
task.WithCertKey(CertKey),
340+
task.WithHeaders(headers),
295341
task.WithWatch(Watch),
296342
task.WithVerbose(Verbose),
297343
task.WithSilent(Silent),
@@ -353,3 +399,49 @@ func getEnvAs[T any](envKey string) (T, bool) {
353399
}
354400
return zero, false
355401
}
402+
403+
// parseHeaderFlags parses header flags in format "host:Header-Name=value" into a map structure.
404+
// Returns an error if any header is malformed.
405+
func parseHeaderFlags(headerFlags []string) (map[string]map[string]string, error) {
406+
if len(headerFlags) == 0 {
407+
return nil, nil
408+
}
409+
410+
headers := make(map[string]map[string]string)
411+
for _, header := range headerFlags {
412+
// Split on first colon to get host and rest
413+
colonIdx := strings.Index(header, ":")
414+
if colonIdx == -1 {
415+
return nil, errors.New("task: invalid --header format, expected 'host:Header-Name=value', got: " + header)
416+
}
417+
418+
host := strings.TrimSpace(header[:colonIdx])
419+
if host == "" {
420+
return nil, errors.New("task: invalid --header format, host cannot be empty: " + header)
421+
}
422+
423+
// Split on first equals to get header name and value
424+
rest := header[colonIdx+1:]
425+
equalIdx := strings.Index(rest, "=")
426+
if equalIdx == -1 {
427+
return nil, errors.New("task: invalid --header format, expected 'host:Header-Name=value', got: " + header)
428+
}
429+
430+
headerName := strings.TrimSpace(rest[:equalIdx])
431+
headerValue := rest[equalIdx+1:] // Don't trim value, might have intentional whitespace
432+
433+
if headerName == "" {
434+
return nil, errors.New("task: invalid --header format, header name cannot be empty: " + header)
435+
}
436+
437+
// Initialize map for this host if needed
438+
if headers[host] == nil {
439+
headers[host] = make(map[string]string)
440+
}
441+
442+
// Add header to map
443+
headers[host][headerName] = headerValue
444+
}
445+
446+
return headers, nil
447+
}

setup.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ func (e *Executor) getRootNode() (taskfile.Node, error) {
5858
taskfile.WithCACert(e.CACert),
5959
taskfile.WithCert(e.Cert),
6060
taskfile.WithCertKey(e.CertKey),
61+
taskfile.WithNodeHeaders(e.Headers),
6162
)
6263
var taskNotFoundError errors.TaskfileNotFoundError
6364
if errors.As(err, &taskNotFoundError) {
@@ -91,6 +92,7 @@ func (e *Executor) readTaskfile(node taskfile.Node) error {
9192
taskfile.WithReaderCACert(e.CACert),
9293
taskfile.WithReaderCert(e.Cert),
9394
taskfile.WithReaderCertKey(e.CertKey),
95+
taskfile.WithHeaders(e.Headers),
9496
taskfile.WithDebugFunc(debugFunc),
9597
taskfile.WithPromptFunc(promptFunc),
9698
)

0 commit comments

Comments
 (0)