diff --git a/internal/server/singleprocess/state/basis.go b/internal/server/singleprocess/state/basis.go index 424f5655f..3141567fa 100644 --- a/internal/server/singleprocess/state/basis.go +++ b/internal/server/singleprocess/state/basis.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/vagrant/internal/server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func init() { @@ -23,16 +24,32 @@ type scope interface { type Basis struct { Model - Vagrantfile *Vagrantfile `mapstructure:"Configuration" gorm:"OnDelete:Cascade"` + Vagrantfile *Vagrantfile `gorm:"constraint:OnDelete:SET NULL" mapstructure:"Configuration"` VagrantfileID *uint `mapstructure:"-"` DataSource *ProtoValue Jobs []*InternalJob `gorm:"polymorphic:Scope" mapstructure:"-"` Metadata MetadataSet - Name *string `gorm:"uniqueIndex,not null"` - Path *string `gorm:"uniqueIndex,not null"` - Projects []*Project + Name string `gorm:"uniqueIndex,not null"` + Path string `gorm:"uniqueIndex,not null"` + Projects []*Project `gorm:"constraint:OnDelete:SET NULL"` RemoteEnabled bool - ResourceId *string `gorm:"<-:create,uniqueIndex,not null"` + ResourceId string `gorm:"uniqueIndex,not null"` // TODO(spox): readonly permission not working as expected +} + +// Returns a fully populated instance of the current basis +func (b *Basis) find(db *gorm.DB) (*Basis, error) { + var basis Basis + result := db.Preload(clause.Associations). + Where(&Basis{ResourceId: b.ResourceId}). + Or(&Basis{Name: b.Name}). + Or(&Basis{Path: b.Path}). + Or(&Basis{Model: Model{ID: b.ID}}). + First(&basis) + if result.Error != nil { + return nil, result.Error + } + + return &basis, nil } func (b *Basis) scope() interface{} { @@ -44,8 +61,40 @@ func (Basis) TableName() string { return "basis" } +// Use before delete hook to remove all assocations +func (b *Basis) BeforeDelete(tx *gorm.DB) error { + basis, err := b.find(tx) + if err != nil { + return err + } + + // If Vagrantfile is attached, delete it + if basis.VagrantfileID != nil { + result := tx.Where(&Vagrantfile{Model: Model{ID: *basis.VagrantfileID}}). + Delete(&Vagrantfile{}) + if result.Error != nil { + return result.Error + } + } + + if len(basis.Projects) > 0 { + if result := tx.Delete(basis.Projects); result.Error != nil { + return result.Error + } + } + + if len(basis.Jobs) > 0 { + result := tx.Delete(basis.Jobs) + if result.Error != nil { + return result.Error + } + } + + return nil +} + func (b *Basis) BeforeSave(tx *gorm.DB) error { - if b.ResourceId == nil { + if b.ResourceId == "" { if err := b.setId(); err != nil { return err } @@ -57,8 +106,44 @@ func (b *Basis) BeforeSave(tx *gorm.DB) error { return nil } +func (b *Basis) BeforeUpdate(tx *gorm.DB) error { + // If a Vagrantfile was already set for the basis, just update it + if b.Vagrantfile != nil && b.Vagrantfile.ID == 0 && b.VagrantfileID != nil { + var v Vagrantfile + result := tx.First(&v, &Vagrantfile{Model: Model{ID: *b.VagrantfileID}}) + if result.Error != nil { + return result.Error + } + id := v.ID + if err := decode(b.Vagrantfile, &v); err != nil { + return err + } + v.ID = id + b.Vagrantfile = &v + + // NOTE: Just updating the value doesn't save the changes so + // save the changes in this transaction + if result := tx.Save(&v); result.Error != nil { + return result.Error + } + } + return nil +} + func (b *Basis) Validate(tx *gorm.DB) error { - err := validation.ValidateStruct(b, + // NOTE: We should be able to use `tx.Statement.Changed("ResourceId")` + // for change detection but it doesn't appear to be set correctly + // so we don't get any notice of change (maybe because it's a pointer?) + existing, err := b.find(tx) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if existing == nil { + existing = &Basis{} + } + + err = validation.ValidateStruct(b, validation.Field(&b.Name, validation.Required, validation.By( @@ -88,6 +173,14 @@ func (b *Basis) Validate(tx *gorm.DB) error { Not(&Basis{Model: Model{ID: b.ID}}), ), ), + validation.When( + b.ID != 0, + validation.By( + checkNotModified( + existing.ResourceId, + ), + ), + ), ), ) @@ -103,7 +196,7 @@ func (b *Basis) setId() error { if err != nil { return err } - b.ResourceId = &id + b.ResourceId = id return nil } @@ -199,7 +292,7 @@ func (s *State) BasisFromProtoRef( } var basis Basis - result := s.search().First(&basis, &Basis{ResourceId: &ref.ResourceId}) + result := s.search().First(&basis, &Basis{ResourceId: ref.ResourceId}) if result.Error != nil { return nil, result.Error } @@ -228,10 +321,10 @@ func (s *State) BasisFromProtoRefFuzzy( query := &Basis{} if ref.Name != "" { - query.Name = &ref.Name + query.Name = ref.Name } if ref.Path != "" { - query.Path = &ref.Path + query.Path = ref.Path } result := s.search().First(basis, query) diff --git a/internal/server/singleprocess/state/basis_test.go b/internal/server/singleprocess/state/basis_test.go index 742949fbf..608bbc940 100644 --- a/internal/server/singleprocess/state/basis_test.go +++ b/internal/server/singleprocess/state/basis_test.go @@ -6,9 +6,292 @@ import ( "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/stretchr/testify/require" + "gorm.io/gorm" ) -func TestBasis(t *testing.T) { +func TestBasis_Create(t *testing.T) { + t.Run("Requires name and path", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{Path: "/dev/null"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires path", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{Name: "default"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Sets resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := Basis{Name: "default", Path: "/dev/null"} + result := db.Save(&basis) + require.NoError(result.Error) + require.NotEmpty(basis.ResourceId) + }) + + t.Run("Retains resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE_ID" + basis := Basis{Name: "default", Path: "/dev/null", ResourceId: rid} + result := db.Save(&basis) + require.NoError(result.Error) + require.EqualValues(rid, basis.ResourceId) + }) + + t.Run("Does not allow duplicate name", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{Name: "default", Path: "/dev/null"}) + require.NoError(result.Error) + result = db.Save(&Basis{Name: "default", Path: "/dev/null/other"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Does not allow duplicate path", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{Name: "default", Path: "/dev/null"}) + require.NoError(result.Error) + result = db.Save(&Basis{Name: "other", Path: "/dev/null"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Does not allow duplicate resource IDs", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE ID" + result := db.Save(&Basis{Name: "default", Path: "/dev/null", ResourceId: rid}) + require.NoError(result.Error) + result = db.Save(&Basis{Name: "other", Path: "/dev/null/other", ResourceId: rid}) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Creates Vagrantfile when set", func(t *testing.T) { + require, db := requireAndDB(t) + + vagrantfile := Vagrantfile{} + basis := Basis{ + Name: "default", + Path: "/dev/null", + Vagrantfile: &vagrantfile, + } + result := db.Save(&basis) + require.NoError(result.Error) + require.NotNil(basis.VagrantfileID) + require.Equal(*basis.VagrantfileID, vagrantfile.ID) + }) +} + +func TestBasis_Update(t *testing.T) { + t.Run("Requires name and path", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := &Basis{Name: "default", Path: "/dev/null"} + result := db.Save(basis) + require.NoError(result.Error) + + basis.Name = "" + basis.Path = "" + result = db.Save(basis) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := &Basis{Name: "default", Path: "/dev/null"} + result := db.Save(basis) + require.NoError(result.Error) + basis.Name = "" + result = db.Save(basis) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires path", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := &Basis{Name: "default", Path: "/dev/null"} + result := db.Save(basis) + require.NoError(result.Error) + basis.Path = "" + result = db.Save(basis) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Does not update resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := Basis{Name: "default", Path: "/dev/null"} + result := db.Save(&basis) + require.NoError(result.Error) + require.NotEmpty(basis.ResourceId) + + var reloadBasis Basis + result = db.First(&reloadBasis, &Basis{Model: Model{ID: basis.ID}}) + require.NoError(result.Error) + + reloadBasis.ResourceId = "NEW VALUE" + result = db.Save(&reloadBasis) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Adds Vagrantfile", func(t *testing.T) { + require, db := requireAndDB(t) + + vpath := "/dev/null/Vagrantfile" + basis := Basis{Name: "default", Path: "/dev/null"} + result := db.Save(&basis) + require.NoError(result.Error) + v := &Vagrantfile{Path: &vpath} + basis.Vagrantfile = v + result = db.Save(&basis) + require.NoError(result.Error) + require.NotEmpty(v.ID) + }) + + t.Run("Updates existing Vagrantfile content", func(t *testing.T) { + require, db := requireAndDB(t) + + // Create inital basis + vpath := "/dev/null/Vagrantfile" + v := &Vagrantfile{Path: &vpath} + basis := Basis{Name: "default", Path: "/dev/null", Vagrantfile: v} + result := db.Save(&basis) + require.NoError(result.Error) + require.NotEmpty(v.ID) + originalID := v.ID + + // Update with new Vagrantfile + newPath := "/dev/null/new" + newV := &Vagrantfile{Path: &newPath} + basis.Vagrantfile = newV + result = db.Save(&basis) + require.NoError(result.Error) + require.Equal(*basis.Vagrantfile.Path, newPath) + require.Equal(originalID, basis.Vagrantfile.ID) + + // Refetch Vagrantfile to ensure persisted changes + var checkVF Vagrantfile + result = db.First(&checkVF, &Vagrantfile{Model: Model{ID: originalID}}) + require.NoError(result.Error) + require.Equal(*checkVF.Path, newPath) + + // Validate only one Vagrantfile has been stored + var count int64 + result = db.Model(&Vagrantfile{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + }) +} + +func TestBasis_Delete(t *testing.T) { + t.Run("Deletes basis", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{Name: "default", Path: "/dev/null"}) + require.NoError(result.Error) + + var basis Basis + result = db.First(&basis, &Basis{Name: "default"}) + require.NoError(result.Error) + + result = db.Where(&Basis{ResourceId: basis.ResourceId}). + Delete(&Basis{}) + require.NoError(result.Error) + result = db.First(&Basis{}, &Basis{ResourceId: basis.ResourceId}) + require.Error(result.Error) + require.ErrorIs(result.Error, gorm.ErrRecordNotFound) + }) + + t.Run("Deletes Vagrantfile", func(t *testing.T) { + require, db := requireAndDB(t) + + vpath := "/dev/null/Vagrantfile" + result := db.Save(&Basis{ + Name: "default", + Path: "/dev/null", + Vagrantfile: &Vagrantfile{Path: &vpath}, + }) + require.NoError(result.Error) + + var count int64 + result = db.Model(&Vagrantfile{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + + result = db.Where(&Basis{Name: "default"}). + Delete(&Basis{}) + require.NoError(result.Error) + result = db.Model((*Vagrantfile)(nil)).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + }) + + t.Run("Deletes Projects", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Basis{ + Name: "default", + Path: "/dev/null", + Projects: []*Project{ + { + Name: "default", + Path: "/dev/null/default", + }, + { + Name: "Other", + Path: "/dev/null/other", + }, + }, + }) + require.NoError(result.Error) + + var count int64 + result = db.Model(&Basis{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + result = db.Model(&Project{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(2), count) + + result = db.Where(&Basis{Name: "default"}). + Delete(&Basis{}) + require.NoError(result.Error) + + result = db.Model(&Basis{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + result = db.Model(&Project{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + }) +} + +func TestBasis_State(t *testing.T) { t.Run("Get returns error if not exist", func(t *testing.T) { require := require.New(t) diff --git a/internal/server/singleprocess/state/box.go b/internal/server/singleprocess/state/box.go index 5be6b8642..74144037c 100644 --- a/internal/server/singleprocess/state/box.go +++ b/internal/server/singleprocess/state/box.go @@ -6,7 +6,6 @@ import ( "time" "github.com/go-ozzo/ozzo-validation/v4" - "github.com/go-ozzo/ozzo-validation/v4/is" "github.com/hashicorp/go-version" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant/internal/server" @@ -27,27 +26,39 @@ const ( type Box struct { Model - Directory *string `gorm:"not null"` - LastUpdate *time.Time `gorm:"autoUpdateTime"` + Directory string `gorm:"uniqueIndex"` + LastUpdate time.Time `gorm:"autoUpdateTime"` Metadata *ProtoValue MetadataUrl *string - Name *string `gorm:"uniqueIndex:idx_nameverprov;not null"` - Provider *string `gorm:"uniqueIndex:idx_nameverprov;not null"` - ResourceId *string `gorm:"<-:create;uniqueIndex;not null"` - Version *string `gorm:"uniqueIndex:idx_nameverprov;not null"` + Name string `gorm:"uniqueIndex:idx_nameverprov"` + Provider string `gorm:"uniqueIndex:idx_nameverprov"` + ResourceId string `gorm:"uniqueIndex"` + Version string `gorm:"uniqueIndex:idx_nameverprov"` +} + +func (b *Box) find(db *gorm.DB) (*Box, error) { + var box Box + result := db.Where(&Box{ResourceId: b.ResourceId}). + Or(&Box{Directory: b.Directory}). + Or(&Box{Name: b.Name, Provider: b.Provider, Version: b.Version}). + First(&box) + if result.Error != nil { + return nil, result.Error + } + + return &box, nil } func (b *Box) BeforeSave(tx *gorm.DB) error { - if b.ResourceId == nil { + if b.ResourceId == "" { if err := b.setId(); err != nil { return err } } // If version is not set, default it to 0 - if b.Version == nil || *b.Version == "0" { - v := DEFAULT_BOX_VERSION - b.Version = &v + if b.Version == "" || b.Version == "0" { + b.Version = DEFAULT_BOX_VERSION } if err := b.Validate(tx); err != nil { @@ -62,14 +73,32 @@ func (b *Box) setId() error { if err != nil { return err } - b.ResourceId = &id + b.ResourceId = id return nil } func (b *Box) Validate(tx *gorm.DB) error { - err := validation.ValidateStruct(b, - validation.Field(&b.Directory, validation.Required), + existing, err := b.find(tx) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if existing == nil { + existing = &Box{} + } + + err = validation.ValidateStruct(b, + validation.Field(&b.Directory, + validation.Required, + validation.By( + checkUnique( + tx.Model((*Box)(nil)). + Where(&Box{Directory: b.Directory}). + Not(&Box{Model: Model{ID: b.ID}}), + ), + ), + ), validation.Field(&b.Name, validation.Required), validation.Field(&b.Provider, validation.Required), validation.Field(&b.ResourceId, @@ -81,10 +110,18 @@ func (b *Box) Validate(tx *gorm.DB) error { Not(&Box{Model: Model{ID: b.ID}}), ), ), + validation.When( + b.ID != 0, + validation.By( + checkNotModified( + existing.ResourceId, + ), + ), + ), ), validation.Field(&b.Version, validation.Required, - is.Semver, + validation.By(checkValidVersion), ), ) @@ -141,7 +178,7 @@ func (s *State) BoxFromProtoRef( } var box Box - result := s.search().First(&box, &Box{ResourceId: &b.ResourceId}) + result := s.search().First(&box, &Box{ResourceId: b.ResourceId}) if result.Error != nil { return nil, result.Error } @@ -168,9 +205,9 @@ func (s *State) BoxFromProtoRefFuzzy( box = &Box{} result := s.search().First(box, &Box{ - Name: &b.Name, - Provider: &b.Provider, - Version: &b.Version, + Name: b.Name, + Provider: b.Provider, + Version: b.Version, }, ) if result.Error != nil { @@ -315,8 +352,8 @@ func (s *State) BoxFind( var boxes []Box result := s.search().Find(&boxes, &Box{ - Name: &b.Name, - Provider: &b.Provider, + Name: b.Name, + Provider: b.Provider, }, ) if result.Error != nil { @@ -326,9 +363,9 @@ func (s *State) BoxFind( return nil, lookupErrorToStatus("box", result.Error) } - // If we found no boxes, return a not found error + // If we found no boxes, return nil result with no error if len(boxes) < 1 { - return nil, nil // lookupErrorToStatus("box", gorm.ErrRecordNotFound) + return nil, nil } // If we have no version value set, apply the default @@ -345,7 +382,7 @@ func (s *State) BoxFind( } for _, box := range boxes { - boxVersion, err := version.NewVersion(*box.Version) + boxVersion, err := version.NewVersion(box.Version) if err != nil { return nil, lookupErrorToStatus("box", err) } @@ -362,5 +399,7 @@ func (s *State) BoxFind( return match.ToProto(), nil } - return nil, nil // lookupErrorToStatus("box", gorm.ErrRecordNotFound) + // If nothing was found, return a nil + // result and no error + return nil, nil } diff --git a/internal/server/singleprocess/state/box_test.go b/internal/server/singleprocess/state/box_test.go index 2dee5d325..18f8e653e 100644 --- a/internal/server/singleprocess/state/box_test.go +++ b/internal/server/singleprocess/state/box_test.go @@ -8,7 +8,177 @@ import ( "github.com/stretchr/testify/require" ) -func TestBox(t *testing.T) { +func TestBox_Create(t *testing.T) { + t.Run("Requires directory, name, provider", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Directory:") + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Provider:") + }) + + t.Run("Requires directory", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{Name: "default", Provider: "virt"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Directory:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{Provider: "virt", Directory: "/dev/null"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires provider", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{Name: "default", Directory: "/dev/null"}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Provider:") + }) + + t.Run("Sets the ResourceId", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + }) + require.NoError(result.Error) + var box Box + result = db.First(&box, &Box{Name: "default"}) + require.NoError(result.Error) + require.NotEmpty(box.ResourceId) + }) + + t.Run("Defaults version when not set", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + }) + require.NoError(result.Error) + var box Box + result = db.First(&box, &Box{Name: "default"}) + require.NoError(result.Error) + require.Equal(DEFAULT_BOX_VERSION, box.Version) + }) + + t.Run("Defaults version when set to 0", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + Version: "0", + }) + require.NoError(result.Error) + var box Box + result = db.First(&box, &Box{Name: "default"}) + require.NoError(result.Error) + require.Equal(DEFAULT_BOX_VERSION, box.Version) + }) + + t.Run("Requires version to be semver", func(t *testing.T) { + require, db := requireAndDB(t) + + box := &Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + Version: "0.a", + } + + result := db.Save(box) + require.Error(result.Error) + require.ErrorContains(result.Error, "Version:") + + box.Version = "string" + result = db.Save(box) + require.Error(result.Error) + require.ErrorContains(result.Error, "Version:") + + box.Version = "a0.1.2" + result = db.Save(box) + require.Error(result.Error) + require.ErrorContains(result.Error, "Version:") + }) + + t.Run("Does not allow duplicates", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + Version: "1.0.0", + }) + require.NoError(result.Error) + + result = db.Save(&Box{ + Name: "default", + Directory: "/dev/null/other", + Provider: "virt", + Version: "1.0.0", + }) + require.Error(result.Error) + require.ErrorContains(result.Error, "name") + require.ErrorContains(result.Error, "provider") + require.ErrorContains(result.Error, "version") + }) + + t.Run("Allows multiple versions", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + Version: "1.0.0", + }) + require.NoError(result.Error) + + result = db.Save(&Box{ + Name: "default", + Directory: "/dev/null/other", + Provider: "virt", + Version: "1.0.1", + }) + require.NoError(result.Error) + }) + + t.Run("Allows multiple providers", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Box{ + Name: "default", + Directory: "/dev/null", + Provider: "virt", + Version: "1.0.0", + }) + require.NoError(result.Error) + + result = db.Save(&Box{ + Name: "default", + Directory: "/dev/null/other", + Provider: "virtz", + Version: "1.0.0", + }) + require.NoError(result.Error) + }) +} + +func TestBox_State(t *testing.T) { t.Run("Get returns error if not exist", func(t *testing.T) { require := require.New(t) @@ -200,7 +370,7 @@ func TestBox(t *testing.T) { Version: "1.2.3", Provider: "dontexist", }) - require.Error(err) + require.NoError(err) require.Nil(b4) b5, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ @@ -208,7 +378,7 @@ func TestBox(t *testing.T) { Version: "9.9.9", Provider: "virtualbox", }) - require.Error(err) + require.NoError(err) require.Nil(b5) b6, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ diff --git a/internal/server/singleprocess/state/config_test.go b/internal/server/singleprocess/state/config_test.go index 0cf34ee18..dbaa10614 100644 --- a/internal/server/singleprocess/state/config_test.go +++ b/internal/server/singleprocess/state/config_test.go @@ -15,7 +15,7 @@ func TestConfig(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create a build require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ @@ -73,7 +73,7 @@ func TestConfig(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create a build require.NoError(s.ConfigSet( @@ -129,7 +129,7 @@ func TestConfig(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create a var require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ @@ -184,7 +184,7 @@ func TestConfig(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create the config require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ @@ -256,7 +256,7 @@ func TestConfig(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create the config require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ @@ -369,7 +369,7 @@ func TestConfigWatch(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) ws := memdb.NewWatchSet() diff --git a/internal/server/singleprocess/state/decoding_test.go b/internal/server/singleprocess/state/decoding_test.go index 74fec0565..1d40a1426 100644 --- a/internal/server/singleprocess/state/decoding_test.go +++ b/internal/server/singleprocess/state/decoding_test.go @@ -25,7 +25,7 @@ func TestSoftDecode(t *testing.T) { s := TestState(t) defer s.Close() - pref := testProject(t, s) + pref := testProjectProto(t, s) tproto := &vagrant_server.Target{ Project: pref, } @@ -35,6 +35,6 @@ func TestSoftDecode(t *testing.T) { require.NoError(err) require.NotNil(target.Project) - require.Equal(*target.Project.ResourceId, tproto.Project.ResourceId) + require.Equal(target.Project.ResourceId, tproto.Project.ResourceId) }) } diff --git a/internal/server/singleprocess/state/job_test.go b/internal/server/singleprocess/state/job_test.go index 499dba5da..12796d37e 100644 --- a/internal/server/singleprocess/state/job_test.go +++ b/internal/server/singleprocess/state/job_test.go @@ -23,11 +23,11 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -59,11 +59,11 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -99,7 +99,7 @@ func TestJobAssign(t *testing.T) { } // Insert another job - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "B", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -277,18 +277,18 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create two builds slightly apart - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, }, }))) time.Sleep(1 * time.Millisecond) - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "B", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -322,12 +322,12 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) - testRunner(t, s, &vagrant_server.Runner{Id: "R_B"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_B"}) // Create a build by ID - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -341,11 +341,11 @@ func TestJobAssign(t *testing.T) { }, }))) time.Sleep(1 * time.Millisecond) - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "B", }))) time.Sleep(1 * time.Millisecond) - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "C", }))) @@ -379,12 +379,12 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_B"}) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_B"}) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build by ID - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -425,10 +425,10 @@ func TestJobAssign(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) r := &vagrant_server.Runner{Id: "R_A", ByIdOnly: true} - testRunner(t, s, r) + testRunnerProto(t, s, r) // Create a build require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ @@ -447,7 +447,7 @@ func TestJobAssign(t *testing.T) { require.Equal(ctx.Err(), err) // Create a target - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "B", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -476,11 +476,11 @@ func TestJobAck(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -512,11 +512,11 @@ func TestJobAck(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -553,11 +553,11 @@ func TestJobAck(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -591,11 +591,11 @@ func TestJobComplete(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -632,11 +632,11 @@ func TestJobComplete(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -676,10 +676,10 @@ func TestJobIsAssignable(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create a build - result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{ + result, err := s.JobIsAssignable(ctx, testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -695,11 +695,11 @@ func TestJobIsAssignable(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Should be assignable - result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{ + result, err := s.JobIsAssignable(ctx, testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -720,11 +720,11 @@ func TestJobIsAssignable(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A", ByIdOnly: true}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A", ByIdOnly: true}) // Should be assignable - result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{ + result, err := s.JobIsAssignable(ctx, testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -745,11 +745,11 @@ func TestJobIsAssignable(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_B"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_B"}) // Should be assignable - result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{ + result, err := s.JobIsAssignable(ctx, testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -772,11 +772,11 @@ func TestJobIsAssignable(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Should be assignable - result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{ + result, err := s.JobIsAssignable(ctx, testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -800,10 +800,10 @@ func TestJobCancel(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) + projRef := testProjectProto(t, s) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -826,11 +826,11 @@ func TestJobCancel(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -859,11 +859,11 @@ func TestJobCancel(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -892,11 +892,11 @@ func TestJobCancel(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -921,7 +921,7 @@ func TestJobCancel(t *testing.T) { require.NotEmpty(job.CancelTime) // Create a another job - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "B", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -945,11 +945,11 @@ func TestJobCancel(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -987,8 +987,8 @@ func TestJobHeartbeat(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Set a short timeout old := jobHeartbeatTimeout @@ -996,7 +996,7 @@ func TestJobHeartbeat(t *testing.T) { jobHeartbeatTimeout = 5 * time.Millisecond // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -1035,11 +1035,11 @@ func TestJobHeartbeat(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -1104,11 +1104,11 @@ func TestJobHeartbeat(t *testing.T) { s := TestState(t) defer s.Close() - projRef := testProject(t, s) - testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + projRef := testProjectProto(t, s) + testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // Create a build - require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ Id: "A", Scope: &vagrant_server.Job_Project{ Project: projRef, @@ -1184,11 +1184,11 @@ func TestJobHeartbeat(t *testing.T) { // s := TestState(t) // defer s.Close() - // projRef := testProject(t, s) - // testRunner(t, s, &vagrant_server.Runner{Id: "R_A"}) + // projRef := testProjectProto(t, s) + // testRunnerProto(t, s, &vagrant_server.Runner{Id: "R_A"}) // // Create a build - // require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{ + // require.NoError(s.JobCreate(testJobProto(t, &vagrant_server.Job{ // Id: "A", // Scope: &vagrant_server.Job_Project{ // Project: projRef, diff --git a/internal/server/singleprocess/state/project.go b/internal/server/singleprocess/state/project.go index ad88ef5c3..d3a13fc3c 100644 --- a/internal/server/singleprocess/state/project.go +++ b/internal/server/singleprocess/state/project.go @@ -8,6 +8,7 @@ import ( "github.com/hashicorp/vagrant/internal/server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func init() { @@ -17,27 +18,72 @@ func init() { type Project struct { Model - Basis *Basis + Basis *Basis `gorm:"constraint:OnDelete:SET NULL"` BasisID uint `gorm:"uniqueIndex:idx_bname" mapstructure:"-"` - Vagrantfile *Vagrantfile `gorm:"OnDelete:Cascade" mapstructure:"Configuration"` - VagrantfileID *uint `mapstructure:"-"` + Vagrantfile *Vagrantfile `gorm:"constraint:OnDelete:SET NULL" mapstructure:"Configuration"` + VagrantfileID *uint `mapstructure:"-" gorm:"constraint:OnDelete:SET NULL"` DataSource *ProtoValue Jobs []*InternalJob `gorm:"polymorphic:Scope"` Metadata MetadataSet - Name *string `gorm:"uniqueIndex:idx_bname,not null"` - Path *string `gorm:"uniqueIndex,not null"` + Name string `gorm:"uniqueIndex:idx_bname,not null"` + Path string `gorm:"uniqueIndex,not null"` RemoteEnabled bool - ResourceId *string `gorm:"<-:create,uniqueIndex,not null"` - Targets []*Target `gorm:"OnDelete:Cascade"` + ResourceId string `gorm:"<-:create,uniqueIndex,not null"` + Targets []*Target } func (p *Project) scope() interface{} { return p } +func (p *Project) find(db *gorm.DB) (*Project, error) { + var project Project + result := db.Preload(clause.Associations). + Where(&Project{ResourceId: p.ResourceId}). + Or(&Project{BasisID: p.BasisID, Name: p.Name}). + Or(&Project{BasisID: p.BasisID, Path: p.Path}). + Or(&Project{Model: Model{ID: p.ID}}). + First(&project) + if result.Error != nil { + return nil, result.Error + } + + return &project, nil +} + +// Use before delete hook to remove all assocations +func (p *Project) BeforeDelete(tx *gorm.DB) error { + project, err := p.find(tx) + if err != nil { + return err + } + + if project.VagrantfileID != nil { + result := tx.Where(&Vagrantfile{Model: Model{ID: *project.VagrantfileID}}). + Delete(&Vagrantfile{}) + if result.Error != nil { + return result.Error + } + } + + if len(project.Targets) > 0 { + if result := tx.Delete(project.Targets); result.Error != nil { + return result.Error + } + } + + if len(project.Jobs) > 0 { + if result := tx.Delete(project.Jobs); result.Error != nil { + return result.Error + } + } + + return nil +} + // Set a public ID on the project before creating func (p *Project) BeforeSave(tx *gorm.DB) error { - if p.ResourceId == nil { + if p.ResourceId == "" { if err := p.setId(); err != nil { return err } @@ -58,17 +104,37 @@ func (p *Project) BeforeUpdate(tx *gorm.DB) error { return result.Error } id := v.ID - if err := decode(p, &v); err != nil { + if err := decode(p.Vagrantfile, &v); err != nil { return err } v.ID = id p.Vagrantfile = &v + + // NOTE: Just updating the value doesn't save the changes so + // save the changes in this transaction + if result := tx.Save(&v); result.Error != nil { + return result.Error + } } return nil } func (p *Project) Validate(tx *gorm.DB) error { - err := validation.ValidateStruct(p, + existing, err := p.find(tx) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if existing == nil { + existing = &Project{} + } + + basisID := p.BasisID + if p.Basis != nil { + basisID = p.Basis.ID + } + + err = validation.ValidateStruct(p, validation.Field(&p.BasisID, validation.Required.When(p.Basis == nil), ), @@ -80,7 +146,7 @@ func (p *Project) Validate(tx *gorm.DB) error { validation.By( checkUnique( tx.Model(&Project{}). - Where(&Project{Name: p.Name, BasisID: p.BasisID}). + Where(&Project{Name: p.Name, BasisID: basisID}). Not(&Project{Model: Model{ID: p.ID}}), ), ), @@ -90,7 +156,7 @@ func (p *Project) Validate(tx *gorm.DB) error { validation.By( checkUnique( tx.Model(&Project{}). - Where(&Project{Path: p.Path, BasisID: p.BasisID}). + Where(&Project{Path: p.Path, BasisID: basisID}). Not(&Project{Model: Model{ID: p.ID}}), ), ), @@ -104,6 +170,14 @@ func (p *Project) Validate(tx *gorm.DB) error { Not(&Project{Model: Model{ID: p.ID}}), ), ), + validation.When( + p.ID != 0, + validation.By( + checkNotModified( + existing.ResourceId, + ), + ), + ), ), ) @@ -119,7 +193,7 @@ func (p *Project) setId() error { if err != nil { return err } - p.ResourceId = &id + p.ResourceId = id return nil } @@ -173,7 +247,7 @@ func (s *State) ProjectFromProtoRef( var project Project result := s.search().First(&project, - &Project{ResourceId: &ref.ResourceId}) + &Project{ResourceId: ref.ResourceId}) if result.Error != nil { return nil, result.Error } @@ -201,14 +275,14 @@ func (s *State) ProjectFromProtoRefFuzzy( query := &Project{} if ref.Name != "" { - query.Name = &ref.Name + query.Name = ref.Name } if ref.Path != "" { - query.Path = &ref.Path + query.Path = ref.Path } result := s.search(). - Joins("Basis", &Basis{ResourceId: &ref.Basis.ResourceId}). + Joins("Basis", &Basis{ResourceId: ref.Basis.ResourceId}). Where(query). First(project) diff --git a/internal/server/singleprocess/state/project_test.go b/internal/server/singleprocess/state/project_test.go index 7523965dd..11b7a2f3b 100644 --- a/internal/server/singleprocess/state/project_test.go +++ b/internal/server/singleprocess/state/project_test.go @@ -10,9 +10,454 @@ import ( "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes" + "gorm.io/gorm" ) -func TestProject(t *testing.T) { +func TestProject_Create(t *testing.T) { + t.Run("Requires name, path, and basis", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Project{}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Path:") + require.ErrorContains(result.Error, "Basis:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Project{ + Path: "/dev/null", + Basis: testBasis(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires path", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Project{ + Name: "default", + Basis: testBasis(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Requires basis", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Basis:") + }) + + t.Run("Sets resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + project := Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(&project) + require.NoError(result.Error) + require.NotEmpty(project.ResourceId) + }) + + t.Run("Retains resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE_ID" + project := Project{ + Name: "default", + Path: "/dev/null", + ResourceId: rid, + Basis: testBasis(t, db), + } + result := db.Save(&project) + require.NoError(result.Error) + require.EqualValues(rid, project.ResourceId) + }) + + t.Run("Does not allow duplicate name in same basis", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := testBasis(t, db) + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + Basis: basis, + }, + ) + require.NoError(result.Error) + result = db.Save( + &Project{ + Name: "default", + Path: "/dev/null/other", + Basis: basis, + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Allows duplicate name in different basis", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Project{ + Name: "default", + Path: "/dev/null/other", + Basis: testBasis(t, db), + }, + ) + require.NoError(result.Error) + }) + + t.Run("Does not allow duplicate path in same basis", func(t *testing.T) { + require, db := requireAndDB(t) + + basis := testBasis(t, db) + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + Basis: basis, + }, + ) + require.NoError(result.Error) + result = db.Save( + &Project{ + Name: "other", + Path: "/dev/null", + Basis: basis, + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Allows duplicate path in different basis", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Project{ + Name: "other", + Path: "/dev/null", + Basis: testBasis(t, db), + }, + ) + require.NoError(result.Error) + }) + + t.Run("Does not allow duplicate resource IDs", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE ID" + result := db.Save( + &Project{ + Name: "default", + Path: "/dev/null", + ResourceId: rid, + Basis: testBasis(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Project{ + Name: "other", + Path: "/dev/null/other", + ResourceId: rid, + Basis: testBasis(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Creates Vagrantfile when set", func(t *testing.T) { + require, db := requireAndDB(t) + + vagrantfile := Vagrantfile{} + project := Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + Vagrantfile: &vagrantfile, + } + result := db.Save(&project) + require.NoError(result.Error) + require.NotNil(project.VagrantfileID) + require.Equal(*project.VagrantfileID, vagrantfile.ID) + }) +} + +func TestProject_Update(t *testing.T) { + t.Run("Requires name and path", func(t *testing.T) { + require, db := requireAndDB(t) + + project := &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(project) + require.NoError(result.Error) + + project.Name = "" + project.Path = "" + result = db.Save(project) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + project := &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(project) + require.NoError(result.Error) + project.Name = "" + result = db.Save(project) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires path", func(t *testing.T) { + require, db := requireAndDB(t) + + project := &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(project) + require.NoError(result.Error) + project.Path = "" + result = db.Save(project) + require.Error(result.Error) + require.ErrorContains(result.Error, "Path:") + }) + + t.Run("Requires basis", func(t *testing.T) { + require, db := requireAndDB(t) + + project := &Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(project) + require.NoError(result.Error) + project.Basis = nil + project.BasisID = 0 + result = db.Save(project) + require.Error(result.Error) + }) + + t.Run("Does not update resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + project := Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(&project) + require.NoError(result.Error) + require.NotNil(project.ResourceId) + require.NotEmpty(project.ResourceId) + + var reloadProject Project + result = db.First(&reloadProject, &Project{Model: Model{ID: project.ID}}) + require.NoError(result.Error) + + reloadProject.ResourceId = "NEW VALUE" + result = db.Save(&reloadProject) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Adds Vagrantfile", func(t *testing.T) { + require, db := requireAndDB(t) + + vpath := "/dev/null/Vagrantfile" + project := Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + } + result := db.Save(&project) + require.NoError(result.Error) + v := &Vagrantfile{Path: &vpath} + project.Vagrantfile = v + result = db.Save(&project) + require.NoError(result.Error) + require.NotEmpty(v.ID) + }) + + t.Run("Updates existing Vagrantfile content", func(t *testing.T) { + require, db := requireAndDB(t) + + // Create inital basis + vpath := "/dev/null/Vagrantfile" + v := &Vagrantfile{Path: &vpath} + project := Project{ + Name: "default", + Path: "/dev/null", + Vagrantfile: v, + Basis: testBasis(t, db), + } + result := db.Save(&project) + require.NoError(result.Error) + require.NotEmpty(v.ID) + originalID := v.ID + + // Update with new Vagrantfile + newPath := "/dev/null/new" + newV := &Vagrantfile{Path: &newPath} + project.Vagrantfile = newV + result = db.Save(&project) + require.NoError(result.Error) + require.Equal(*project.Vagrantfile.Path, newPath) + require.Equal(originalID, project.Vagrantfile.ID) + + // Refetch Vagrantfile to ensure persisted changes + var checkVF Vagrantfile + result = db.First(&checkVF, &Vagrantfile{Model: Model{ID: originalID}}) + require.NoError(result.Error) + require.Equal(*checkVF.Path, newPath) + + // Validate only one Vagrantfile has been stored + var count int64 + result = db.Model(&Vagrantfile{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + }) +} + +func TestProject_Delete(t *testing.T) { + t.Run("Deletes project", func(t *testing.T) { + require, db := requireAndDB(t) + + seedProject := testProject(t, db) + + var project Project + result := db.First(&project, + &Project{ + Name: seedProject.Name, + Path: seedProject.Path, + }, + ) + require.NoError(result.Error) + + result = db.Where(&Project{ResourceId: project.ResourceId}). + Delete(&Project{}) + require.NoError(result.Error) + result = db.First(&Project{}, &Project{ResourceId: project.ResourceId}) + require.Error(result.Error) + require.ErrorIs(result.Error, gorm.ErrRecordNotFound) + }) + + t.Run("Deletes Vagrantfile", func(t *testing.T) { + require, db := requireAndDB(t) + + vpath := "/dev/null/Vagrantfile" + result := db.Save(&Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + Vagrantfile: &Vagrantfile{Path: &vpath}, + }) + require.NoError(result.Error) + + var count int64 + result = db.Model(&Vagrantfile{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + + result = db.Where(&Project{Name: "default"}). + Delete(&Basis{}) + require.NoError(result.Error) + result = db.Model((*Vagrantfile)(nil)).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + }) + + t.Run("Deletes targets", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Project{ + Name: "default", + Path: "/dev/null", + Basis: testBasis(t, db), + Targets: []*Target{ + { + Name: "default", + }, + { + Name: "Other", + }, + }, + }) + require.NoError(result.Error) + + var count int64 + result = db.Model(&Project{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(1), count) + result = db.Model(&Target{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(2), count) + + result = db.Where(&Project{Name: "default"}). + Delete(&Project{}) + require.NoError(result.Error) + + result = db.Model(&Project{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + result = db.Model(&Target{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + }) +} + +func TestProject_State(t *testing.T) { t.Run("Get returns not found error if not exist", func(t *testing.T) { require := require.New(t) @@ -32,7 +477,7 @@ func TestProject(t *testing.T) { s := TestState(t) defer s.Close() - basisRef := testBasis(t, s) + basisRef := testBasisProto(t, s) // Set result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{ @@ -65,7 +510,7 @@ func TestProject(t *testing.T) { s := TestState(t) defer s.Close() - basisRef := testBasis(t, s) + basisRef := testBasisProto(t, s) // Set result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{ diff --git a/internal/server/singleprocess/state/target.go b/internal/server/singleprocess/state/target.go index fd4c4dbfd..2fda76cf8 100644 --- a/internal/server/singleprocess/state/target.go +++ b/internal/server/singleprocess/state/target.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/vagrant/internal/server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func init() { @@ -21,26 +22,57 @@ type Target struct { Configuration *ProtoValue Jobs []*InternalJob `gorm:"polymorphic:Scope;" mapstructure:"-"` Metadata MetadataSet - Name *string `gorm:"uniqueIndex:idx_pname;not null"` + Name string `gorm:"uniqueIndex:idx_pname;not null"` Parent *Target `gorm:"foreignkey:ID"` ParentID *uint `mapstructure:"-"` Project *Project ProjectID uint `gorm:"uniqueIndex:idx_pname" mapstructure:"-"` Provider *string Record *ProtoValue - ResourceId *string `gorm:"<-:create;uniqueIndex;not null"` + ResourceId string `gorm:"uniqueIndex"` State vagrant_server.Operation_PhysicalState - Subtargets []*Target `gorm:"foreignkey:ParentID"` + Subtargets []*Target `gorm:"foreignkey:ParentID;constraint:OnDelete:SET NULL"` Uuid *string `gorm:"uniqueIndex"` } +func (t *Target) find(db *gorm.DB) (*Target, error) { + var target Target + result := db.Preload(clause.Associations). + Where(&Target{ResourceId: t.ResourceId}). + Or(&Target{Uuid: t.Uuid}). + Or(&Target{ProjectID: t.ProjectID, Name: t.Name}). + Or(&Target{Model: Model{ID: t.ID}}). + First(&target) + if result.Error != nil { + return nil, result.Error + } + + return &target, nil +} + func (t *Target) scope() interface{} { return t } +// Use before delete hook to remove all associations +func (t *Target) BeforeDelete(tx *gorm.DB) error { + target, err := t.find(tx) + if err != nil { + return err + } + + if len(target.Subtargets) > 0 { + if result := tx.Delete(target.Subtargets); result.Error != nil { + return result.Error + } + } + + return nil +} + // Set a public ID on the target before creating func (t *Target) BeforeSave(tx *gorm.DB) error { - if t.ResourceId == nil { + if t.ResourceId == "" { if err := t.setId(); err != nil { return err } @@ -53,14 +85,48 @@ func (t *Target) BeforeSave(tx *gorm.DB) error { return nil } +// NOTE: Need better validation on parent <-> subtarget +// project matching. It currently does basic check but +// will miss edge cases easily. func (t *Target) validate(tx *gorm.DB) error { - err := validation.ValidateStruct(t, + existing, err := t.find(tx) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if existing == nil { + existing = &Target{} + } + + projectID := t.ProjectID + if t.Project != nil { + projectID = t.Project.ID + } + + parent := &Target{} + parentProjectID := uint(0) + if t.Parent != nil { + parent = t.Parent + } else if t.ParentID != nil { + result := tx.First(parent, &Target{Model: Model{ID: *t.ParentID}}) + if result.Error != nil { + return result.Error + } + } + if parent != nil { + parentProjectID = parent.ProjectID + if parent.Project != nil { + parentProjectID = parent.Project.ID + } + } + + err = validation.ValidateStruct(t, validation.Field(&t.Name, validation.Required, validation.By( checkUnique( tx.Model(&Target{}). - Where(&Target{Name: t.Name, ProjectID: t.ProjectID}). + Where(&Target{Name: t.Name, ProjectID: projectID}). Not(&Target{Model: Model{ID: t.ID}}), ), ), @@ -88,9 +154,21 @@ func (t *Target) validate(tx *gorm.DB) error { ), validation.Field(&t.ProjectID, validation.Required.When(t.Project == nil), + validation.When( + t.ProjectID != 0 && parentProjectID != 0, + validation.By( + checkSameProject(parentProjectID), + ), + ), ), validation.Field(&t.Project, validation.Required.When(t.ProjectID == 0), + validation.When( + t.Project != nil && parentProjectID != 0, + validation.By( + checkSameProject(parentProjectID), + ), + ), ), ) @@ -106,7 +184,7 @@ func (t *Target) setId() error { if err != nil { return err } - t.ResourceId = &id + t.ResourceId = id return nil } @@ -157,7 +235,7 @@ func (s *State) TargetFromProtoRef( var target Target result := s.search().Preload("Project.Basis").First(&target, - &Target{ResourceId: &ref.ResourceId}, + &Target{ResourceId: ref.ResourceId}, ) if result.Error != nil { return nil, result.Error @@ -184,9 +262,9 @@ func (s *State) TargetFromProtoRefFuzzy( target = &Target{} result := s.search(). - Joins("Project", &Project{ResourceId: &ref.Project.ResourceId}). + Joins("Project", &Project{ResourceId: ref.Project.ResourceId}). Preload("Project.Basis"). - First(target, &Target{Name: &ref.Name}) + First(target, &Target{Name: ref.Name}) if result.Error != nil { return nil, result.Error @@ -239,7 +317,7 @@ func (s *State) TargetFromProtoFuzzy( Preload("Project.Basis"). Where("Project.resource_id = ?", t.Project.ResourceId) - result := tx.First(target, &Target{Name: &t.Name}) + result := tx.First(target, &Target{Name: t.Name}) if result.Error != nil { return nil, result.Error } diff --git a/internal/server/singleprocess/state/target_test.go b/internal/server/singleprocess/state/target_test.go index d5e0cc0f7..d72f7e59f 100644 --- a/internal/server/singleprocess/state/target_test.go +++ b/internal/server/singleprocess/state/target_test.go @@ -8,12 +8,351 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/wrapperspb" + "gorm.io/gorm" + "gorm.io/gorm/clause" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" ) -func TestTarget(t *testing.T) { +func TestTarget_Create(t *testing.T) { + t.Run("Requires name and project", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Target{}) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + require.ErrorContains(result.Error, "Project:") + }) + + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Target{ + Project: testProject(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Requires project", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Target{ + Name: "default", + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Project:") + }) + + t.Run("Sets resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + target := Target{ + Name: "default", + Project: testProject(t, db), + } + result := db.Save(&target) + require.NoError(result.Error) + require.NotEmpty(target.ResourceId) + }) + + t.Run("Retains resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE_ID" + target := Target{ + Name: "default", + ResourceId: rid, + Project: testProject(t, db), + } + result := db.Save(&target) + require.NoError(result.Error) + require.NotNil(target.ResourceId) + require.EqualValues(rid, target.ResourceId) + }) + + t.Run("Does not allow duplicate name in same project", func(t *testing.T) { + require, db := requireAndDB(t) + + project := testProject(t, db) + result := db.Save( + &Target{ + Name: "default", + Project: project, + }, + ) + require.NoError(result.Error) + result = db.Save( + &Target{ + Name: "default", + Project: project, + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Allows duplicate name in different projects", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save( + &Target{ + Name: "default", + Project: testProject(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Target{ + Name: "default", + Project: testProject(t, db), + }, + ) + require.NoError(result.Error) + }) + + t.Run("Does not allow duplicate resource IDs", func(t *testing.T) { + require, db := requireAndDB(t) + + rid := "RESOURCE ID" + result := db.Save( + &Target{ + Name: "default", + ResourceId: rid, + Project: testProject(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Target{ + Name: "other", + ResourceId: rid, + Project: testProject(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Does not allow duplicate UUIDs", func(t *testing.T) { + require, db := requireAndDB(t) + + uuid := "UUID VALUE" + result := db.Save( + &Target{ + Name: "default", + Uuid: &uuid, + Project: testProject(t, db), + }, + ) + require.NoError(result.Error) + result = db.Save( + &Target{ + Name: "other", + Uuid: &uuid, + Project: testProject(t, db), + }, + ) + require.Error(result.Error) + require.ErrorContains(result.Error, "Uuid:") + }) + + t.Run("Stores a record when set", func(t *testing.T) { + require, db := requireAndDB(t) + + record := &vagrant_server.Target_Machine{ + Id: "MACHINE_ID", + } + result := db.Save( + &Target{ + Name: "default", + Project: testProject(t, db), + Record: &ProtoValue{Message: record}, + }, + ) + require.NoError(result.Error) + var target Target + result = db.First(&target, &Target{Name: "default"}) + require.NoError(result.Error) + require.Equal(record.Id, target.Record.Message.(*vagrant_server.Target_Machine).Id) + }) + + t.Run("Properly creates child targets", func(t *testing.T) { + require, db := requireAndDB(t) + + project := testProject(t, db) + result := db.Save( + &Target{ + Name: "parent", + Project: project, + Subtargets: []*Target{ + { + Name: "subtarget1", + Project: project, + }, + { + Name: "subtarget2", + Project: project, + }, + { + Name: "subtarget3", + Project: project, + }, + }, + }, + ) + require.NoError(result.Error) + var target Target + result = db.Preload(clause.Associations). + First(&target, &Target{Name: "parent"}) + require.NoError(result.Error) + require.Equal(3, len(target.Subtargets)) + }) +} + +func TestTarget_Update(t *testing.T) { + t.Run("Requires name", func(t *testing.T) { + require, db := requireAndDB(t) + + target := &Target{Name: "default", Project: testProject(t, db)} + result := db.Save(target) + require.NoError(result.Error) + + target.Name = "" + result = db.Save(target) + require.Error(result.Error) + require.ErrorContains(result.Error, "Name:") + }) + + t.Run("Does not update resource ID", func(t *testing.T) { + require, db := requireAndDB(t) + + target := Target{Name: "default", Project: testProject(t, db)} + result := db.Save(&target) + require.NoError(result.Error) + require.NotEmpty(target.ResourceId) + + var reloadTarget Basis + result = db.First(&reloadTarget, &Target{Model: Model{ID: target.ID}}) + require.NoError(result.Error) + + reloadTarget.ResourceId = "NEW VALUE" + result = db.Save(&reloadTarget) + require.Error(result.Error) + require.ErrorContains(result.Error, "ResourceId:") + }) + + t.Run("Adds subtarget", func(t *testing.T) { + require, db := requireAndDB(t) + + project := testProject(t, db) + target := Target{ + Name: "parent", + Project: project, + Subtargets: []*Target{ + { + Name: "subtarget1", + Project: project, + }, + }, + } + result := db.Save(&target) + require.NoError(result.Error) + result = db.Preload(clause.Associations).First(&target, &Target{Name: "parent"}) + require.NoError(result.Error) + require.Equal(1, len(target.Subtargets)) + target.Subtargets = append(target.Subtargets, &Target{ + Name: "subtarget2", + Project: project, + }) + result = db.Save(&target) + require.NoError(result.Error) + result = db.Preload(clause.Associations).First(&target, &Target{Name: "parent"}) + require.NoError(result.Error) + require.Equal(2, len(target.Subtargets)) + }) + + t.Run("It fails to add subtarget with different project", func(t *testing.T) { + require, db := requireAndDB(t) + + target := Target{ + Name: "parent", + Project: testProject(t, db), + } + result := db.Save(&target) + require.NoError(result.Error) + result = db.First(&target, &Target{Name: "parent"}) + require.NoError(result.Error) + target.Subtargets = append(target.Subtargets, &Target{ + Name: "subtarget", + Project: testProject(t, db), + }) + result = db.Save(&target) + require.Error(result.Error) + }) +} + +func TestTarget_Delete(t *testing.T) { + t.Run("Deletes target", func(t *testing.T) { + require, db := requireAndDB(t) + + result := db.Save(&Target{Name: "default", Project: testProject(t, db)}) + require.NoError(result.Error) + + var target Target + result = db.First(&target, &Target{Name: "default"}) + require.NoError(result.Error) + + result = db.Where(&Target{ResourceId: target.ResourceId}). + Delete(&Target{}) + require.NoError(result.Error) + result = db.First(&Target{}, &Target{ResourceId: target.ResourceId}) + require.Error(result.Error) + require.ErrorIs(result.Error, gorm.ErrRecordNotFound) + }) + + t.Run("Deletes subtargets", func(t *testing.T) { + require, db := requireAndDB(t) + + project := testProject(t, db) + result := db.Save( + &Target{ + Name: "parent", + Project: project, + Subtargets: []*Target{ + { + Name: "subtarget1", + Project: project, + }, + { + Name: "subtarget2", + Project: project, + }, + }, + }, + ) + require.NoError(result.Error) + + var count int64 + result = db.Model(&Target{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(3), count) + + result = db.Where(&Target{Name: "parent"}). + Delete(&Target{}) + require.NoError(result.Error) + result = db.Model(&Target{}).Count(&count) + require.NoError(result.Error) + require.Equal(int64(0), count) + }) +} + +func TestTarget_State(t *testing.T) { t.Run("Get returns not found error if not exist", func(t *testing.T) { require := require.New(t) @@ -33,7 +372,7 @@ func TestTarget(t *testing.T) { s := TestState(t) defer s.Close() - projectRef := testProject(t, s) + projectRef := testProjectProto(t, s) // Set result, err := s.TargetPut(&vagrant_server.Target{ @@ -67,7 +406,7 @@ func TestTarget(t *testing.T) { Project: projectRef, Name: "test", }) - require.Error(err) + require.NoError(err) // Ensure there is still one entry resp, err = s.TargetList() @@ -147,7 +486,7 @@ func TestTarget(t *testing.T) { s := TestState(t) defer s.Close() - projectRef := testProject(t, s) + projectRef := testProjectProto(t, s) // Set result, err := s.TargetPut(&vagrant_server.Target{ @@ -194,7 +533,7 @@ func TestTarget(t *testing.T) { s := TestState(t) defer s.Close() - projectRef := testProject(t, s) + projectRef := testProjectProto(t, s) // Set result, err := s.TargetPut(&vagrant_server.Target{ diff --git a/internal/server/singleprocess/state/testing.go b/internal/server/singleprocess/state/testing.go index d4d37fd57..38a8f485a 100644 --- a/internal/server/singleprocess/state/testing.go +++ b/internal/server/singleprocess/state/testing.go @@ -89,6 +89,10 @@ func testDB(t testing.T) *gorm.DB { db, err := gorm.Open(sqlite.Open(""), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) + db.Exec("PRAGMA foreign_keys = ON") + if err != nil { + panic("failed to enable foreign key constraints: " + err.Error()) + } require.NoError(t, err) t.Cleanup(func() { @@ -101,42 +105,58 @@ func testDB(t testing.T) *gorm.DB { return db } -// TestBasis creates the basis in the DB. -func testBasis(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Basis { - t.Helper() - - td := testTempDir(t) - name := filepath.Base(td) - b := &Basis{ - Name: &name, - Path: &td, +func requireAndDB(t testing.T) (*require.Assertions, *gorm.DB) { + db := testDB(t) + require := require.New(t) + if err := db.AutoMigrate(models...); err != nil { + require.NoError(err) } - result := s.db.Save(b) - require.NoError(t, result.Error) - - return b.ToProtoRef() + return require, db } -func testProject(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Project { +func testBasis(t testing.T, db *gorm.DB) *Basis { t.Helper() - basisRef := testBasis(t, s) - b, err := s.BasisFromProtoRef(basisRef) - require.NoError(t, err) td := testTempDir(t) - name := filepath.Base(td) + b := &Basis{ + Name: filepath.Base(td), + Path: td, + } + result := db.Save(b) + require.NoError(t, result.Error) + + return b +} + +// TestBasis creates the basis in the DB. +func testBasisProto(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Basis { + t.Helper() + + return testBasis(t, s.db).ToProtoRef() +} + +func testProject(t testing.T, db *gorm.DB) *Project { + b := testBasis(t, db) + + td := testTempDir(t) p := &Project{ - Name: &name, - Path: &td, + Name: filepath.Base(td), + Path: td, Basis: b, } - result := s.db.Save(p) + result := db.Save(p) require.NoError(t, result.Error) - return p.ToProtoRef() + return p } -func testRunner(t testing.T, s *State, src *vagrant_server.Runner) *vagrant_server.Runner { +func testProjectProto(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Project { + t.Helper() + + return testProject(t, s.db).ToProtoRef() +} + +func testRunnerProto(t testing.T, s *State, src *vagrant_server.Runner) *vagrant_server.Runner { t.Helper() if src == nil { @@ -155,7 +175,7 @@ func testRunner(t testing.T, s *State, src *vagrant_server.Runner) *vagrant_serv return runner.ToProto() } -func testJob(t testing.T, src *vagrant_server.Job) *vagrant_server.Job { +func testJobProto(t testing.T, src *vagrant_server.Job) *vagrant_server.Job { t.Helper() require.NoError(t, mergo.Merge(src, diff --git a/internal/server/singleprocess/state/validations.go b/internal/server/singleprocess/state/validations.go index 3f89b2be9..206a02eb9 100644 --- a/internal/server/singleprocess/state/validations.go +++ b/internal/server/singleprocess/state/validations.go @@ -1,14 +1,22 @@ package state import ( + "fmt" + "reflect" + "github.com/go-ozzo/ozzo-validation/v4" + "github.com/hashicorp/go-version" "gorm.io/gorm" ) type ValidationCode string const ( - VALIDATION_UNIQUE ValidationCode = "unique" + VALIDATION_UNIQUE ValidationCode = "unique" + VALIDATION_MODIFIED = "modified" + VALIDATION_PROJECT = "project" + VALIDATION_TYPE = "type" + VALIDATION_VERSION = "version" ) func checkUnique(tx *gorm.DB) validation.RuleFunc { @@ -29,3 +37,58 @@ func checkUnique(tx *gorm.DB) validation.RuleFunc { return nil } } + +func checkNotModified(original interface{}) validation.RuleFunc { + return func(value interface{}) error { + if !reflect.DeepEqual(original, value) { + return validation.NewError( + string(VALIDATION_MODIFIED), + "cannot be modified", + ) + } + + return nil + } +} + +func checkSameProject(projectID uint) validation.RuleFunc { + return func(value interface{}) error { + vPid, ok := value.(uint) + if !ok { + project, ok := value.(*Project) + if !ok { + return validation.NewError( + string(VALIDATION_TYPE), + fmt.Sprintf("*Project or uint required for validation (%T)", value), + ) + } + vPid = project.ID + } + if vPid != projectID { + return validation.NewError( + string(VALIDATION_PROJECT), + "project must match parent project", + ) + } + return nil + } +} + +func checkValidVersion(value interface{}) error { + v, ok := value.(string) + if !ok { + return validation.NewError( + string(VALIDATION_TYPE), + fmt.Sprintf("version string required for validation (%T)", value), + ) + } + _, err := version.NewVersion(v) + if err != nil { + return validation.NewError( + string(VALIDATION_VERSION), + "invalid version string", + ) + } + + return nil +}