diff --git a/internal/config/path.go b/internal/config/path.go index e8719df6b..8b8241032 100644 --- a/internal/config/path.go +++ b/internal/config/path.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "github.com/hashicorp/vagrant-plugin-sdk/helper/path" @@ -25,25 +26,37 @@ func GetVagrantfileName() string { // filename is empty, it will default to the Filename constant. func FindPath(dir path.Path, filename string) (p path.Path, err error) { if dir == nil { - cwd, err := os.Getwd() - if err != nil { - return nil, err + cwd, ok := os.LookupEnv("VAGRANT_CWD") + if ok { + if _, err := os.Stat(cwd); os.IsNotExist(err) { + return nil, fmt.Errorf("VAGRANT_CWD set to path (%s) that does not exist", cwd) + } else { + dir = path.NewPath(cwd) + } + } else { + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + dir = path.NewPath(cwd) } - dir = path.NewPath(cwd) } if filename == "" { filename = GetVagrantfileName() } + p = dir for { - p = dir.Join(filename) + p = p.Join(filename) if _, err = os.Stat(p.String()); err == nil || !os.IsNotExist(err) { return } - if p.String() == p.Parent().String() { + // since we just tacked a filename on above, the first Parent() call is + // the directory of the file and the second is the actual parent dir + if p.Parent().String() == p.Parent().Parent().String() { return nil, nil } - p = p.Parent() + p = p.Parent().Parent() } } diff --git a/internal/config/path_test.go b/internal/config/path_test.go new file mode 100644 index 000000000..d7460cb16 --- /dev/null +++ b/internal/config/path_test.go @@ -0,0 +1,127 @@ +package config + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/hashicorp/vagrant-plugin-sdk/helper/path" + "github.com/stretchr/testify/require" +) + +func TestFindPath(t *testing.T) { + t.Run("uses dir and filename args if passed and file exists", func(t *testing.T) { + require := require.New(t) + + dir, err := ioutil.TempDir("", "test") + require.NoError(err) + defer os.RemoveAll(dir) + + p := filepath.Join(dir, "MyCoolFile") + file, err := os.Create(p) + require.NoError(err) + file.Close() + + out, err := FindPath(path.NewPath(dir), "MyCoolFile") + require.NoError(err) + require.Equal(p, out.String()) + }) + + t.Run("when VAGRANT_CWD is not set", func(t *testing.T) { + oldVcwd, ok := os.LookupEnv("VAGRANT_CWD") + if ok { + os.Unsetenv("VAGRANT_CWD") + defer os.Setenv("VAGRANT_CWD", oldVcwd) + } + + t.Run("uses cwd and Vagrantfile when blank args passed", func(t *testing.T) { + require := require.New(t) + + dir, err := ioutil.TempDir("", "test") + require.NoError(err) + defer os.RemoveAll(dir) + + p := filepath.Join(dir, "Vagrantfile") + file, err := os.Create(p) + require.NoError(err) + file.Close() + + oldCwd, err := os.Getwd() + require.NoError(err) + os.Chdir(dir) + defer os.Chdir(oldCwd) + + out, err := FindPath(nil, "") + require.NoError(err) + require.Equal(p, out.String()) + }) + + t.Run("walks parent dirs looking for Vagrantfile", func(t *testing.T) { + require := require.New(t) + + dir, err := ioutil.TempDir("", "test") + require.NoError(err) + defer os.RemoveAll(dir) + + deepPath := path.NewPath(filepath.Join(dir, "a", "b")) + err = os.MkdirAll(deepPath.String(), 0700) + require.NoError(err) + + notDeepFile := filepath.Join(dir, "Vagrantfile") + file, err := os.Create(notDeepFile) + require.NoError(err) + file.Close() + + out, err := FindPath(deepPath, "") + require.NoError(err) + require.Equal(notDeepFile, out.String()) + }) + + t.Run("returns nil if parent walk comes up empty", func(t *testing.T) { + require := require.New(t) + + dir, err := ioutil.TempDir("", "test") + require.NoError(err) + defer os.RemoveAll(dir) + + deepPath := path.NewPath(filepath.Join(dir, "a", "b")) + err = os.MkdirAll(deepPath.String(), 0700) + require.NoError(err) + + out, err := FindPath(deepPath, "") + require.NoError(err) + require.Nil(out) + }) + }) + + t.Run("honors VAGRANT_CWD if set and exists", func(t *testing.T) { + require := require.New(t) + + dir, err := ioutil.TempDir("", "test") + require.NoError(err) + defer os.RemoveAll(dir) + + os.Setenv("VAGRANT_CWD", dir) + defer os.Unsetenv("VAGRANT_CWD") + + file, err := os.Create(filepath.Join(dir, "Vagrantfile")) + require.NoError(err) + file.Close() + + out, err := FindPath(nil, "") + require.NoError(err) + require.Equal(filepath.Join(dir, "Vagrantfile"), out.String()) + }) + + t.Run("errors if VAGRANT_CWD is set and does not exist", func(t *testing.T) { + require := require.New(t) + + os.Setenv("VAGRANT_CWD", filepath.Join(os.TempDir(), "idontexit")) + defer os.Unsetenv("VAGRANT_CWD") + + _, err := FindPath(nil, "") + require.Error(err) + require.Contains(err.Error(), "does not exist") + }) +}