diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 19fb3a9f..724c0c6a 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -61,9 +61,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } c := config.NewC(l) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 6551ceb4..7c2b39c8 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -3,8 +3,6 @@ package main import ( "fmt" "log" - "os" - "path/filepath" "github.com/kardianos/service" "github.com/slackhq/nebula" @@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error { return nil } -func fileExists(filename string) bool { - _, err := os.Stat(filename) - if os.IsNotExist(err) { - return false - } - return true -} - func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { - ex, err := os.Executable() + p, err := config.DefaultPath() if err != nil { return err } - *configPath = filepath.Dir(ex) + "/config.yaml" - if !fileExists(*configPath) { - *configPath = filepath.Dir(ex) + "/config.yml" - } + *configPath = p } svcConfig := &service.Config{ diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index d7f0de93..219519c2 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -50,9 +50,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } l := logging.NewLogger(os.Stdout) diff --git a/config/default.go b/config/default.go new file mode 100644 index 00000000..9494c655 --- /dev/null +++ b/config/default.go @@ -0,0 +1,29 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml over config.yml. +// If neither file exists an error is returned that names both paths checked. +func DefaultPath() (string, error) { + ex, err := os.Executable() + if err != nil { + return "", err + } + return defaultPathInDir(filepath.Dir(ex)) +} + +func defaultPathInDir(dir string) (string, error) { + yamlPath := filepath.Join(dir, "config.yaml") + if _, err := os.Stat(yamlPath); err == nil { + return yamlPath, nil + } + ymlPath := filepath.Join(dir, "config.yml") + if _, err := os.Stat(ymlPath); err == nil { + return ymlPath, nil + } + return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath) +} diff --git a/config/default_test.go b/config/default_test.go new file mode 100644 index 00000000..bb0a14d3 --- /dev/null +++ b/config/default_test.go @@ -0,0 +1,67 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultPathInDir(t *testing.T) { + t.Run("prefers config.yaml when both exist", func(t *testing.T) { + dir := t.TempDir() + yaml := filepath.Join(dir, "config.yaml") + yml := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(yaml, []byte("a: 1"), 0644)) + require.NoError(t, os.WriteFile(yml, []byte("a: 2"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, yaml, got) + }) + + t.Run("returns config.yaml when only it exists", func(t *testing.T) { + dir := t.TempDir() + yaml := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(yaml, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, yaml, got) + }) + + t.Run("falls back to config.yml when only it exists", func(t *testing.T) { + dir := t.TempDir() + yml := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(yml, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, yml, got) + }) + + t.Run("errors when neither exists and names both paths", func(t *testing.T) { + dir := t.TempDir() + got, err := defaultPathInDir(dir) + assert.Empty(t, got) + require.Error(t, err) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml")) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml")) + }) +} + +func TestDefaultPath(t *testing.T) { + got, err := DefaultPath() + if err != nil { + ex, exErr := os.Executable() + require.NoError(t, exErr) + assert.Contains(t, err.Error(), filepath.Dir(ex)) + return + } + ex, err := os.Executable() + require.NoError(t, err) + assert.Equal(t, filepath.Dir(ex), filepath.Dir(got)) + assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got)) +}