Migrate data layer to gorm

This commit is contained in:
Chris Roberts 2022-09-19 08:40:11 -07:00
parent 0a0333adb7
commit f24ab4d855
24 changed files with 3870 additions and 2090 deletions

View File

@ -1,326 +1,340 @@
package state package state
import ( import (
"strings" "errors"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
) )
var basisBucket = []byte("basis")
func init() { func init() {
dbBuckets = append(dbBuckets, basisBucket) models = append(models, &Basis{})
dbIndexers = append(dbIndexers, (*State).basisIndexInit)
schemas = append(schemas, basisIndexSchema)
} }
func (s *State) BasisFind(b *vagrant_server.Basis) (*vagrant_server.Basis, error) { // This interface is utilized internally as an
memTxn := s.inmem.Txn(false) // identifier for scopes to allow for easier mapping
defer memTxn.Abort() type scope interface {
scope() interface{}
}
var result *vagrant_server.Basis type Basis struct {
err := s.db.View(func(dbTxn *bolt.Tx) error { gorm.Model
var err error
result, err = s.basisFind(dbTxn, memTxn, b) Vagrantfile *Vagrantfile `mapstructure:"-"`
VagrantfileID uint `mapstructure:"-"`
DataSource *ProtoValue
Jobs []*InternalJob `gorm:"polymorphic:Scope;" mapstructure:"-"`
Metadata MetadataSet
Name *string `gorm:"uniqueIndex,not null"`
Path *string `gorm:"uniqueIndex,not null"`
Projects []*Project
RemoteEnabled bool
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
}
func (b *Basis) scope() interface{} {
return b
}
// Define custom table name
func (Basis) TableName() string {
return "basis"
}
func (b *Basis) BeforeSave(tx *gorm.DB) error {
if b.ResourceId == nil {
if err := b.setId(); err != nil {
return err
}
}
if err := b.Validate(tx); err != nil {
return err return err
})
return result, err
}
func (s *State) BasisGet(ref *vagrant_plugin_sdk.Ref_Basis) (*vagrant_server.Basis, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Basis
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.basisGet(dbTxn, memTxn, ref)
return err
})
return result, err
}
func (s *State) BasisPut(b *vagrant_server.Basis) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.basisPut(dbTxn, memTxn, b)
})
if err == nil {
memTxn.Commit()
} }
return err return nil
} }
func (s *State) BasisDelete(ref *vagrant_plugin_sdk.Ref_Basis) error { func (b *Basis) Validate(tx *gorm.DB) error {
memTxn := s.inmem.Txn(true) err := validation.ValidateStruct(b,
defer memTxn.Abort() validation.Field(&b.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{Name: b.Name}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.Path,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{Path: b.Path}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{ResourceId: b.ResourceId}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
)
err := s.db.Update(func(dbTxn *bolt.Tx) error { if err != nil {
return s.basisDelete(dbTxn, memTxn, ref) return err
})
if err == nil {
memTxn.Commit()
} }
return err return nil
} }
func (s *State) BasisList() ([]*vagrant_plugin_sdk.Ref_Basis, error) { func (b *Basis) setId() error {
memTxn := s.inmem.Txn(false) id, err := server.Id()
defer memTxn.Abort() if err != nil {
return err
}
b.ResourceId = &id
return s.basisList(memTxn) return nil
} }
func (s *State) basisGet( // Convert basis to protobuf message
dbTxn *bolt.Tx, func (b *Basis) ToProto() *vagrant_server.Basis {
memTxn *memdb.Txn, if b == nil {
ref *vagrant_plugin_sdk.Ref_Basis, return nil
) (*vagrant_server.Basis, error) { }
var result vagrant_server.Basis
b := dbTxn.Bucket(basisBucket) basis := vagrant_server.Basis{}
return &result, dbGet(b, s.basisIdByRef(ref), &result) err := decode(b, &basis)
if err != nil {
panic("failed to decode basis: " + err.Error())
}
if b.Vagrantfile != nil {
basis.Configuration = b.Vagrantfile.ToProto()
}
return &basis
} }
func (s *State) basisFind( // Convert basis to reference protobuf message
dbTxn *bolt.Tx, func (b *Basis) ToProtoRef() *vagrant_plugin_sdk.Ref_Basis {
memTxn *memdb.Txn, if b == nil {
return nil
}
ref := vagrant_plugin_sdk.Ref_Basis{}
err := decode(b, &ref)
if err != nil {
panic("failed to decode basis to ref: " + err.Error())
}
return &ref
}
// Load a Basis from a protobuf message. This will only search
// against the resource id.
func (s *State) BasisFromProto(
b *vagrant_server.Basis, b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) { ) (*Basis, error) {
var match *basisIndexRecord if b == nil {
return nil, ErrEmptyProtoArgument
// Start with the resource id first
if b.ResourceId != "" {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexIdIndexName,
b.ResourceId,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
}
// Try the name next
if b.Name != "" && match == nil {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexNameIndexName,
b.Name,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
}
// And finally the path
if b.Path != "" && match == nil {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexPathIndexName,
b.Path,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
} }
if match == nil { basis, err := s.BasisFromProtoRef(
return nil, status.Errorf(codes.NotFound, "record not found for Basis") &vagrant_plugin_sdk.Ref_Basis{
} ResourceId: b.ResourceId,
},
return s.basisGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Basis{ )
ResourceId: match.Id,
})
}
func (s *State) basisPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Basis,
) (err error) {
s.log.Trace("storing basis", "basis", value)
if value.ResourceId == "" {
s.log.Trace("basis has no resource id, assuming new basis",
"basis", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for basis", "basis", value,
"error", err)
return
}
}
s.log.Trace("storing basis to db", "basis", value)
id := s.basisId(value)
b := dbTxn.Bucket(basisBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store basis in db", "basis", value, "error", err)
return
}
s.log.Trace("indexing basis", "basis", value)
if err = s.basisIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index basis", "basis", value, "error", err)
return
}
return
}
func (s *State) basisList(
memTxn *memdb.Txn,
) ([]*vagrant_plugin_sdk.Ref_Basis, error) {
iter, err := memTxn.Get(basisIndexTableName, basisIndexIdIndexName+"_prefix", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result []*vagrant_plugin_sdk.Ref_Basis return basis, nil
for {
next := iter.Next()
if next == nil {
break
}
idx := next.(*basisIndexRecord)
result = append(result, &vagrant_plugin_sdk.Ref_Basis{
ResourceId: idx.Id,
Name: idx.Name,
})
}
return result, nil
} }
func (s *State) basisDelete( // Load a Basis from a protobuf message. This will attempt to locate the
dbTxn *bolt.Tx, // basis using any unique field it can match.
memTxn *memdb.Txn, func (s *State) BasisFromProtoFuzzy(
ref *vagrant_plugin_sdk.Ref_Basis, b *vagrant_server.Basis,
) error { ) (*Basis, error) {
b, err := s.basisGet(dbTxn, memTxn, ref) if b == nil {
return nil, ErrEmptyProtoArgument
}
basis, err := s.BasisFromProtoRefFuzzy(
&vagrant_plugin_sdk.Ref_Basis{
ResourceId: b.ResourceId,
Name: b.Name,
Path: b.Path,
},
)
if err != nil { if err != nil {
if status.Code(err) == codes.NotFound { return nil, err
return nil
}
return err
} }
for _, p := range b.Projects { return basis, nil
if err := s.projectDelete(dbTxn, memTxn, p); err != nil { }
return err
// Load a Basis from a reference protobuf message
func (s *State) BasisFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*Basis, error) {
if ref == nil {
return nil, ErrEmptyProtoArgument
}
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var basis Basis
result := s.search().First(&basis, &Basis{ResourceId: &ref.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &basis, nil
}
func (s *State) BasisFromProtoRefFuzzy(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*Basis, error) {
basis, err := s.BasisFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if basis != nil {
return basis, nil
}
// If name and path are both empty, we can't search
if ref.Name == "" && ref.Path == "" {
return nil, gorm.ErrRecordNotFound
}
basis = &Basis{}
query := &Basis{}
if ref.Name != "" {
query.Name = &ref.Name
}
if ref.Path != "" {
query.Path = &ref.Path
}
result := s.search().First(basis, query)
if result.Error != nil {
return nil, result.Error
}
return basis, nil
}
// Get a basis record using a reference protobuf message.
func (s *State) BasisGet(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*vagrant_server.Basis, error) {
b, err := s.BasisFromProtoRef(ref)
if err != nil {
return nil, lookupErrorToStatus("basis", err)
}
return b.ToProto(), nil
}
// Find a basis record using a protobuf message
func (s *State) BasisFind(
b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) {
basis, err := s.BasisFromProtoFuzzy(b)
if err != nil {
return nil, lookupErrorToStatus("basis", err)
}
return basis.ToProto(), nil
}
// Store a basis record
func (s *State) BasisPut(
b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) {
basis, err := s.BasisFromProto(b)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("basis", err)
}
// Make sure we don't have a nil
if err != nil {
basis = &Basis{}
}
err = s.softDecode(b, basis)
if err != nil {
return nil, saveErrorToStatus("basis", err)
}
if b.Configuration != nil {
if basis.Vagrantfile != nil {
basis.Vagrantfile.UpdateFromProto(b.Configuration)
} else {
basis.Vagrantfile = s.VagrantfileFromProto(b.Configuration)
} }
} }
// Delete from bolt result := s.db.Save(basis)
if err := dbTxn.Bucket(basisBucket).Delete(s.basisId(b)); err != nil { if result.Error != nil {
return err return nil, saveErrorToStatus("basis", result.Error)
} }
// Delete from memdb
record := s.newBasisIndexRecord(b) return basis.ToProto(), nil
if err := memTxn.Delete(basisIndexTableName, record); err != nil { }
return err
// List all basis records
func (s *State) BasisList() ([]*Basis, error) {
var all []*Basis
result := s.search().Find(&all)
if result.Error != nil {
return nil, lookupErrorToStatus("basis", result.Error)
} }
return all, nil
}
// Delete a basis
func (s *State) BasisDelete(
b *vagrant_plugin_sdk.Ref_Basis,
) error {
basis, err := s.BasisFromProtoRef(b)
// If the record was not found, we return with no error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
// If an unexpected error was encountered, return it
if err != nil {
return lookupErrorToStatus("basis", err)
}
result := s.db.Delete(basis)
if result.Error != nil {
return deleteErrorToStatus("basis", result.Error)
}
return nil return nil
} }
func (s *State) basisIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Basis) error { var (
return txn.Insert(basisIndexTableName, s.newBasisIndexRecord(value)) _ scope = (*Basis)(nil)
}
func (s *State) basisIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(basisBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Basis
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
if err := s.basisIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil
})
}
func basisIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: basisIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
basisIndexIdIndexName: {
Name: basisIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
basisIndexNameIndexName: {
Name: basisIndexNameIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
basisIndexPathIndexName: {
Name: basisIndexPathIndexName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Path",
Lowercase: false,
},
},
},
}
}
const (
basisIndexIdIndexName = "id"
basisIndexNameIndexName = "name"
basisIndexPathIndexName = "path"
basisIndexTableName = "basis-index"
) )
type basisIndexRecord struct {
Id string
Name string
Path string
}
func (s *State) newBasisIndexRecord(b *vagrant_server.Basis) *basisIndexRecord {
return &basisIndexRecord{
Id: b.ResourceId,
Name: strings.ToLower(b.Name),
Path: b.Path,
}
}
func (s *State) newBasisIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Basis) *basisIndexRecord {
return &basisIndexRecord{
Id: ref.ResourceId,
Name: strings.ToLower(ref.Name),
}
}
func (s *State) basisId(b *vagrant_server.Basis) []byte {
return []byte(b.ResourceId)
}
func (s *State) basisIdByRef(ref *vagrant_plugin_sdk.Ref_Basis) []byte {
if ref == nil {
return []byte{}
}
return []byte(ref.ResourceId)
}

View File

@ -19,6 +19,42 @@ func TestBasis(t *testing.T) {
require.Error(err) require.Error(err)
}) })
t.Run("Put creates and sets resource ID", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
testBasis := &vagrant_server.Basis{
Name: "test_name",
Path: "/User/test/test",
}
result, err := s.BasisPut(testBasis)
require.NoError(err)
require.NotEmpty(result.ResourceId)
})
t.Run("Put fails on duplicate name", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
testBasis := &vagrant_server.Basis{
Name: "test_name",
Path: "/User/test/test",
}
// Set initial record
_, err := s.BasisPut(testBasis)
require.NoError(err)
// Attempt to set it again
_, err = s.BasisPut(testBasis)
require.Error(err)
})
t.Run("Put and Get", func(t *testing.T) { t.Run("Put and Get", func(t *testing.T) {
require := require.New(t) require := require.New(t)
@ -26,38 +62,23 @@ func TestBasis(t *testing.T) {
defer s.Close() defer s.Close()
testBasis := &vagrant_server.Basis{ testBasis := &vagrant_server.Basis{
ResourceId: "test", Name: "test_name",
Name: "test_name", Path: "/User/test/test",
Path: "/User/test/test",
}
testBasisRef := &vagrant_plugin_sdk.Ref_Basis{
ResourceId: "test",
Name: "test_name",
Path: "/User/test/test",
} }
// Set // Set
err := s.BasisPut(testBasis) result, err := s.BasisPut(testBasis)
require.NoError(err) require.NoError(err)
// Get full ref testBasisRef := &vagrant_plugin_sdk.Ref_Basis{
{ ResourceId: result.ResourceId,
resp, err := s.BasisGet(testBasisRef)
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBasis.Name)
} }
// Get by id // Get full ref
{ resp, err := s.BasisGet(testBasisRef)
resp, err := s.BasisGet(&vagrant_plugin_sdk.Ref_Basis{ require.NoError(err)
ResourceId: "test", require.NotNil(resp)
}) require.Equal(resp.Name, testBasis.Name)
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBasis.Name)
}
}) })
t.Run("Find", func(t *testing.T) { t.Run("Find", func(t *testing.T) {
@ -67,19 +88,18 @@ func TestBasis(t *testing.T) {
defer s.Close() defer s.Close()
testBasis := &vagrant_server.Basis{ testBasis := &vagrant_server.Basis{
ResourceId: "test", Name: "test_name",
Name: "test_name", Path: "/User/test/test",
Path: "/User/test/test",
} }
// Set // Set
err := s.BasisPut(testBasis) result, err := s.BasisPut(testBasis)
require.NoError(err) require.NoError(err)
// Find by resource id // Find by resource id
{ {
resp, err := s.BasisFind(&vagrant_server.Basis{ resp, err := s.BasisFind(&vagrant_server.Basis{
ResourceId: "test", ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
@ -114,9 +134,8 @@ func TestBasis(t *testing.T) {
defer s.Close() defer s.Close()
testBasis := &vagrant_server.Basis{ testBasis := &vagrant_server.Basis{
ResourceId: "test", Name: "test_name",
Name: "test_name", Path: "/User/test/test",
Path: "/User/test/test",
} }
testBasisRef := &vagrant_plugin_sdk.Ref_Basis{ResourceId: "test"} testBasisRef := &vagrant_plugin_sdk.Ref_Basis{ResourceId: "test"}
@ -126,8 +145,9 @@ func TestBasis(t *testing.T) {
require.NoError(err) require.NoError(err)
// Add basis // Add basis
err = s.BasisPut(testBasis) result, err := s.BasisPut(testBasis)
require.NoError(err) require.NoError(err)
testBasisRef.ResourceId = result.ResourceId
// No error when deleting basis // No error when deleting basis
err = s.BasisDelete(testBasisRef) err = s.BasisDelete(testBasisRef)
@ -145,17 +165,15 @@ func TestBasis(t *testing.T) {
defer s.Close() defer s.Close()
// Add basis' // Add basis'
err := s.BasisPut(&vagrant_server.Basis{ _, err := s.BasisPut(&vagrant_server.Basis{
ResourceId: "test", Name: "test_name",
Name: "test_name", Path: "/User/test/test",
Path: "/User/test/test",
}) })
require.NoError(err) require.NoError(err)
err = s.BasisPut(&vagrant_server.Basis{ _, err = s.BasisPut(&vagrant_server.Basis{
ResourceId: "test2", Name: "test_name2",
Name: "test_name2", Path: "/User/test/test2",
Path: "/User/test/test2",
}) })
require.NoError(err) require.NoError(err)

View File

@ -1,305 +1,361 @@
package state package state
import ( import (
"google.golang.org/protobuf/proto" "errors"
"github.com/hashicorp/go-memdb" "fmt"
"time"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
bolt "go.etcd.io/bbolt" "github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes" "gorm.io/gorm"
"google.golang.org/grpc/status"
) )
var boxBucket = []byte("box")
func init() { func init() {
dbBuckets = append(dbBuckets, boxBucket) models = append(models, &Box{})
dbIndexers = append(dbIndexers, (*State).boxIndexInit)
schemas = append(schemas, boxIndexSchema)
}
func (s *State) BoxList() ([]*vagrant_plugin_sdk.Ref_Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
return s.boxList(memTxn)
}
func (s *State) BoxDelete(ref *vagrant_plugin_sdk.Ref_Box) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.boxDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
}
return err
}
func (s *State) BoxGet(ref *vagrant_plugin_sdk.Ref_Box) (*vagrant_server.Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Box
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.boxGet(dbTxn, memTxn, ref)
return err
})
return result, err
}
func (s *State) BoxPut(box *vagrant_server.Box) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.boxPut(dbTxn, memTxn, box)
})
if err == nil {
memTxn.Commit()
}
return err
}
func (s *State) BoxFind(b *vagrant_plugin_sdk.Ref_Box) (*vagrant_server.Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Box
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.boxFind(dbTxn, memTxn, b)
return err
})
return result, err
}
func (s *State) boxList(
memTxn *memdb.Txn,
) (r []*vagrant_plugin_sdk.Ref_Box, err error) {
iter, err := memTxn.Get(boxIndexTableName, boxIndexIdIndexName+"_prefix", "")
if err != nil {
return nil, err
}
var result []*vagrant_plugin_sdk.Ref_Box
for {
next := iter.Next()
if next == nil {
break
}
result = append(result, &vagrant_plugin_sdk.Ref_Box{
ResourceId: next.(*boxIndexRecord).Id,
Name: next.(*boxIndexRecord).Name,
Version: next.(*boxIndexRecord).Version,
Provider: next.(*boxIndexRecord).Provider,
})
}
return result, nil
}
func (s *State) boxDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Box,
) (err error) {
b, err := s.boxGet(dbTxn, memTxn, ref)
if err != nil {
if status.Code(err) == codes.NotFound {
return nil
}
return err
}
// Delete the box
if err = dbTxn.Bucket(boxBucket).Delete(s.boxId(b)); err != nil {
return
}
if err = memTxn.Delete(boxIndexTableName, s.newBoxIndexRecord(b)); err != nil {
return
}
return
}
func (s *State) boxGet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Box,
) (r *vagrant_server.Box, err error) {
var result vagrant_server.Box
b := dbTxn.Bucket(boxBucket)
return &result, dbGet(b, s.boxIdByRef(ref), &result)
}
func (s *State) boxPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Box,
) (err error) {
id := s.boxId(value)
b := dbTxn.Bucket(boxBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store box in db", "box", value, "error", err)
return
}
s.log.Trace("indexing box", "box", value)
if err = s.boxIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index box", "box", value, "error", err)
return
}
return
}
func (s *State) boxFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Box,
) (r *vagrant_server.Box, err error) {
var match *boxIndexRecord
highestVersion, _ := version.NewVersion("0.0.0")
req := s.newBoxIndexRecordByRef(ref)
// Get the name first
if req.Name != "" {
raw, err := memTxn.Get(
boxIndexTableName,
boxIndexNameIndexName,
req.Name,
)
if err != nil {
return nil, err
}
if req.Version == "" {
req.Version = ">= 0"
}
versionConstraint, err := version.NewConstraint(req.Version)
if err != nil {
return nil, err
}
for e := raw.Next(); e != nil; e = raw.Next() {
boxIndexEntry := e.(*boxIndexRecord)
if req.Version != "" {
boxVersion, _ := version.NewVersion(boxIndexEntry.Version)
if !versionConstraint.Check(boxVersion) {
continue
}
}
if req.Provider != "" {
if boxIndexEntry.Provider != req.Provider {
continue
}
}
// Set first match
if match == nil {
match = boxIndexEntry
}
v, _ := version.NewVersion(boxIndexEntry.Version)
if v.GreaterThan(highestVersion) {
highestVersion = v
match = boxIndexEntry
}
}
if match != nil {
return s.boxGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Box{
ResourceId: match.Id,
})
}
}
return
} }
const ( const (
boxIndexIdIndexName = "id" DEFAULT_BOX_VERSION = "0.0.0"
boxIndexNameIndexName = "name" DEFAULT_BOX_CONSTRAINT = "> 0"
boxIndexTableName = "box-index"
) )
type boxIndexRecord struct { type Box struct {
Id string // Resource ID gorm.Model
Name string // Box Name
Version string // Box Version Directory *string `gorm:"not null"`
Provider string // Box Provider LastUpdate *time.Time `gorm:"autoUpdateTime"`
Metadata *ProtoValue
MetadataUrl *string
Name *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
Provider *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
Version *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
} }
func (s *State) newBoxIndexRecord(b *vagrant_server.Box) *boxIndexRecord { func (b *Box) BeforeSave(tx *gorm.DB) error {
id := b.Name + "-" + b.Version + "-" + b.Provider if b.ResourceId == nil {
return &boxIndexRecord{ if err := b.setId(); err != nil {
Id: id, return err
Name: b.Name, }
Version: b.Version,
Provider: b.Provider,
} }
// If version is not set, default it to 0
if b.Version == nil || *b.Version == "0" {
v := DEFAULT_BOX_VERSION
b.Version = &v
}
if err := b.Validate(tx); err != nil {
return err
}
return nil
} }
func (s *State) boxIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Box) error { func (b *Box) setId() error {
return txn.Insert(boxIndexTableName, s.newBoxIndexRecord(value)) id, err := server.Id()
if err != nil {
return err
}
b.ResourceId = &id
return nil
} }
func (s *State) boxIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error { func (b *Box) Validate(tx *gorm.DB) error {
bucket := dbTxn.Bucket(boxBucket) err := validation.ValidateStruct(b,
return bucket.ForEach(func(k, v []byte) error { validation.Field(&b.Directory, validation.Required),
var value vagrant_server.Box validation.Field(&b.Name, validation.Required),
if err := proto.Unmarshal(v, &value); err != nil { validation.Field(&b.Provider, validation.Required),
return err validation.Field(&b.ResourceId,
} validation.Required,
if err := s.boxIndexSet(memTxn, k, &value); err != nil { validation.By(
return err checkUnique(
} tx.Model(&Box{}).
Where(&Box{ResourceId: b.ResourceId}).
Not(&Box{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.Version,
validation.Required,
is.Semver,
),
)
return nil if err != nil {
}) return err
}
err = validation.Validate(b,
validation.By(
checkUnique(
tx.Model(&Box{}).
Where(&Box{Name: b.Name, Provider: b.Provider, Version: b.Version}).
Not(&Box{Model: gorm.Model{ID: b.ID}}),
),
),
)
if err != nil {
return fmt.Errorf("name, provider and version %s", err)
}
return nil
} }
func boxIndexSchema() *memdb.TableSchema { func (b *Box) ToProto() *vagrant_server.Box {
return &memdb.TableSchema{ var p vagrant_server.Box
Name: boxIndexTableName, err := decode(b, &p)
Indexes: map[string]*memdb.IndexSchema{ if err != nil {
boxIndexIdIndexName: { panic(fmt.Sprintf("failed to decode box: " + err.Error()))
Name: boxIndexIdIndexName, }
AllowMissing: false,
Unique: true, return &p
Indexer: &memdb.StringFieldIndex{ }
Field: "Id",
Lowercase: true, func (b *Box) ToProtoRef() *vagrant_plugin_sdk.Ref_Box {
}, var p vagrant_plugin_sdk.Ref_Box
}, err := decode(b, &p)
boxIndexNameIndexName: { if err != nil {
Name: boxIndexNameIndexName, panic(fmt.Sprintf("failed to decode box ref: " + err.Error()))
AllowMissing: false, }
Unique: false,
Indexer: &memdb.StringFieldIndex{ return &p
Field: "Name", }
Lowercase: true,
}, func (s *State) BoxFromProtoRef(
}, b *vagrant_plugin_sdk.Ref_Box,
) (*Box, error) {
if b == nil {
return nil, ErrEmptyProtoArgument
}
if b.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var box Box
result := s.search().First(&box, &Box{ResourceId: &b.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &box, nil
}
func (s *State) BoxFromProtoRefFuzzy(
b *vagrant_plugin_sdk.Ref_Box,
) (*Box, error) {
box, err := s.BoxFromProtoRef(b)
if err == nil {
return box, nil
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if b.Name == "" || b.Provider == "" || b.Version == "" {
return nil, gorm.ErrRecordNotFound
}
box = &Box{}
result := s.search().First(box,
&Box{
Name: &b.Name,
Provider: &b.Provider,
Version: &b.Version,
}, },
)
if result.Error != nil {
return nil, result.Error
} }
return box, nil
} }
func (s *State) newBoxIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Box) *boxIndexRecord { func (s *State) BoxFromProto(
return &boxIndexRecord{ b *vagrant_server.Box,
Id: ref.ResourceId, ) (*Box, error) {
Name: ref.Name, return s.BoxFromProtoRef(
Version: ref.Version, &vagrant_plugin_sdk.Ref_Box{
Provider: ref.Provider, ResourceId: b.ResourceId,
},
)
}
func (s *State) BoxFromProtoFuzzy(
b *vagrant_server.Box,
) (*Box, error) {
return s.BoxFromProtoRefFuzzy(
&vagrant_plugin_sdk.Ref_Box{
Name: b.Name,
Provider: b.Provider,
ResourceId: b.ResourceId,
Version: b.Version,
},
)
}
func (s *State) BoxList() ([]*vagrant_plugin_sdk.Ref_Box, error) {
var boxes []Box
result := s.db.Find(&boxes)
if result.Error != nil {
return nil, lookupErrorToStatus("boxes", result.Error)
} }
refs := make([]*vagrant_plugin_sdk.Ref_Box, len(boxes))
for i, b := range boxes {
refs[i] = b.ToProtoRef()
}
return refs, nil
} }
func (s *State) boxId(b *vagrant_server.Box) []byte { func (s *State) BoxDelete(
return []byte(b.Id) b *vagrant_plugin_sdk.Ref_Box,
) error {
box, err := s.BoxFromProtoRef(b)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
if err != nil {
return deleteErrorToStatus("box", err)
}
result := s.db.Delete(box)
if result.Error != nil {
return deleteErrorToStatus("box", result.Error)
}
return nil
} }
func (s *State) boxIdByRef(b *vagrant_plugin_sdk.Ref_Box) []byte { func (s *State) BoxGet(
return []byte(b.ResourceId) b *vagrant_plugin_sdk.Ref_Box,
) (*vagrant_server.Box, error) {
box, err := s.BoxFromProtoRef(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
func (s *State) BoxPut(b *vagrant_server.Box) error {
box, err := s.BoxFromProto(b)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return lookupErrorToStatus("box", err)
}
if err != nil {
box = &Box{}
}
err = s.softDecode(b, box)
if err != nil {
return saveErrorToStatus("box", err)
}
result := s.db.Save(box)
if result.Error != nil {
return saveErrorToStatus("box", result.Error)
}
return nil
}
func (s *State) BoxFind(
ref *vagrant_plugin_sdk.Ref_Box,
) (*vagrant_server.Box, error) {
b := &vagrant_plugin_sdk.Ref_Box{}
if err := mapstructure.Decode(ref, b); err != nil {
return nil, lookupErrorToStatus("box", err)
}
if b.ResourceId != "" {
box, err := s.BoxFromProtoRef(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
// If no name is given, we error immediately
if b.Name == "" {
return nil, lookupErrorToStatus("box", fmt.Errorf("no name given for box lookup"))
}
// If no provider is given, we error immediately
if b.Provider == "" {
return nil, lookupErrorToStatus("box", fmt.Errorf("no provider given for box lookup"))
}
// If the version is set to 0, mark it as default
if b.Version == "0" {
b.Version = DEFAULT_BOX_VERSION
}
// If we are provided an explicit version, just do a direct lookup
if _, err := version.NewVersion(b.Version); err == nil {
box, err := s.BoxFromProtoRefFuzzy(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
var boxes []Box
result := s.search().Find(&boxes,
&Box{
Name: &b.Name,
Provider: &b.Provider,
},
)
if result.Error != nil {
return nil, lookupErrorToStatus("box", result.Error)
}
// If we found no boxes, return a not found error
if len(boxes) < 1 {
return nil, lookupErrorToStatus("box", gorm.ErrRecordNotFound)
}
// If we have no version value set, apply the default
// version constraint
if b.Version == "" {
b.Version = DEFAULT_BOX_CONSTRAINT
}
var match *Box
highestVersion, _ := version.NewVersion("0.0.0")
versionConstraint, err := version.NewConstraint(b.Version)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
for _, box := range boxes {
boxVersion, err := version.NewVersion(*box.Version)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
if !versionConstraint.Check(boxVersion) {
continue
}
if boxVersion.GreaterThan(highestVersion) {
match = &box
highestVersion = boxVersion
}
}
if match != nil {
return match.ToProto(), nil
}
return nil, lookupErrorToStatus("box", gorm.ErrRecordNotFound)
} }

View File

@ -26,10 +26,19 @@ func TestBox(t *testing.T) {
defer s.Close() defer s.Close()
testBox := &vagrant_server.Box{ testBox := &vagrant_server.Box{
Id: "qwerwasdf", ResourceId: "qwerwasdf",
Name: "hashicorp/bionic", Directory: "/directory",
Version: "1.2.3", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.3",
Provider: "virtualbox",
}
testBox2 := &vagrant_server.Box{
ResourceId: "qwerwasdf-2",
Directory: "/directory-2",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
} }
testBoxRef := &vagrant_plugin_sdk.Ref_Box{ testBoxRef := &vagrant_plugin_sdk.Ref_Box{
@ -39,9 +48,18 @@ func TestBox(t *testing.T) {
Provider: "virtualbox", Provider: "virtualbox",
} }
testBoxRef2 := &vagrant_plugin_sdk.Ref_Box{
ResourceId: "qwerwasdf-2",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
}
// Set // Set
err := s.BoxPut(testBox) err := s.BoxPut(testBox)
require.NoError(err) require.NoError(err)
err = s.BoxPut(testBox2)
require.NoError(err)
// Get full ref // Get full ref
{ {
@ -49,6 +67,11 @@ func TestBox(t *testing.T) {
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
require.Equal(resp.Name, testBox.Name) require.Equal(resp.Name, testBox.Name)
resp, err = s.BoxGet(testBoxRef2)
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBox2.Name)
} }
// Get by id // Get by id
@ -69,10 +92,11 @@ func TestBox(t *testing.T) {
defer s.Close() defer s.Close()
testBox := &vagrant_server.Box{ testBox := &vagrant_server.Box{
Id: "qwerwasdf", ResourceId: "qwerwasdf",
Name: "hashicorp/bionic", Directory: "/directory",
Version: "1.2.3", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.3",
Provider: "virtualbox",
} }
testBoxRef := &vagrant_plugin_sdk.Ref_Box{ testBoxRef := &vagrant_plugin_sdk.Ref_Box{
@ -98,18 +122,20 @@ func TestBox(t *testing.T) {
defer s.Close() defer s.Close()
err := s.BoxPut(&vagrant_server.Box{ err := s.BoxPut(&vagrant_server.Box{
Id: "qwerwasdf", ResourceId: "qwerwasdf",
Name: "hashicorp/bionic", Directory: "/directory",
Version: "1.2.3", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.3",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{ err = s.BoxPut(&vagrant_server.Box{
Id: "rrbrwasdf", ResourceId: "rrbrwasdf",
Name: "hashicorp/bionic", Directory: "/other-directory",
Version: "1.2.4", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.4",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
@ -125,44 +151,40 @@ func TestBox(t *testing.T) {
defer s.Close() defer s.Close()
err := s.BoxPut(&vagrant_server.Box{ err := s.BoxPut(&vagrant_server.Box{
Id: "hashicorp/bionic-1.2.3-virtualbox", ResourceId: "hashicorp/bionic-1.2.3-virtualbox",
Name: "hashicorp/bionic", Directory: "/directory",
Version: "1.2.3", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.3",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{ err = s.BoxPut(&vagrant_server.Box{
Id: "hashicorp/bionic-1.2.4-virtualbox", ResourceId: "hashicorp/bionic-1.2.4-virtualbox",
Name: "hashicorp/bionic", Directory: "/other-directory",
Version: "1.2.4", Name: "hashicorp/bionic",
Provider: "virtualbox", Version: "1.2.4",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{ err = s.BoxPut(&vagrant_server.Box{
Id: "box-0-virtualbox", ResourceId: "box-0-virtualbox",
Name: "box", Directory: "/another-directory",
Version: "0", Name: "box",
Provider: "virtualbox", Version: "0",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
b, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic", Name: "hashicorp/bionic",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
require.Equal(b.Name, "hashicorp/bionic") require.Equal(b.Name, "hashicorp/bionic")
require.Equal(b.Version, "1.2.4") require.Equal(b.Version, "1.2.4")
b2, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Version: "1.2.3",
})
require.NoError(err)
require.Equal(b2.Name, "hashicorp/bionic")
require.Equal(b2.Version, "1.2.3")
b3, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b3, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic", Name: "hashicorp/bionic",
Version: "1.2.3", Version: "1.2.3",
@ -178,7 +200,7 @@ func TestBox(t *testing.T) {
Version: "1.2.3", Version: "1.2.3",
Provider: "dontexist", Provider: "dontexist",
}) })
require.NoError(err) require.Error(err)
require.Nil(b4) require.Nil(b4)
b5, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b5, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
@ -186,32 +208,34 @@ func TestBox(t *testing.T) {
Version: "9.9.9", Version: "9.9.9",
Provider: "virtualbox", Provider: "virtualbox",
}) })
require.NoError(err) require.Error(err)
require.Nil(b5) require.Nil(b5)
b6, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b6, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Version: "1.2.3", Version: "1.2.3",
}) })
require.NoError(err) require.Error(err)
require.Nil(b6) require.Nil(b6)
b7, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b7, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "dontexist", Name: "dontexist",
}) })
require.NoError(err) require.Error(err)
require.Nil(b7) require.Nil(b7)
b8, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b8, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic", Name: "hashicorp/bionic",
Version: "~> 1.2", Provider: "virtualbox",
Version: "~> 1.2",
}) })
require.NoError(err) require.NoError(err)
require.Equal(b8.Name, "hashicorp/bionic") require.Equal(b8.Name, "hashicorp/bionic")
require.Equal(b8.Version, "1.2.4") require.Equal(b8.Version, "1.2.4")
b9, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b9, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic", Name: "hashicorp/bionic",
Version: "> 1.0, < 3.0", Provider: "virtualbox",
Version: "> 1.0, < 3.0",
}) })
require.NoError(err) require.NoError(err)
require.Equal(b9.Name, "hashicorp/bionic") require.Equal(b9.Name, "hashicorp/bionic")
@ -221,15 +245,16 @@ func TestBox(t *testing.T) {
Name: "hashicorp/bionic", Name: "hashicorp/bionic",
Version: "< 1.0", Version: "< 1.0",
}) })
require.NoError(err) require.Error(err)
require.Nil(b10) require.Nil(b10)
b11, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{ b11, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "box", Name: "box",
Version: "0", Version: "0",
Provider: "virtualbox",
}) })
require.NoError(err) require.NoError(err)
require.Equal(b11.Name, "box") require.Equal(b11.Name, "box")
require.Equal(b11.Version, "0") require.Equal(b11.Version, "0.0.0")
}) })
} }

View File

@ -0,0 +1,70 @@
package state
import (
"github.com/hashicorp/vagrant-plugin-sdk/component"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
type Component struct {
gorm.Model
Name string `gorm:"uniqueIndex:idx_stname"`
ServerAddr string `gorm:"uniqueIndex:idx_stname"`
Type component.Type `gorm:"uniqueIndex:idx_stname"`
Runners []*Runner `gorm:"many2many:runner_components"`
}
func init() {
models = append(models, &Component{})
}
func (c *Component) ToProtoRef() *vagrant_server.Ref_Component {
if c == nil {
return nil
}
return &vagrant_server.Ref_Component{
Type: vagrant_server.Component_Type(c.Type),
Name: c.Name,
}
}
func (c *Component) ToProto() *vagrant_server.Component {
if c == nil {
return nil
}
return &vagrant_server.Component{
Type: vagrant_server.Component_Type(c.Type),
Name: c.Name,
ServerAddr: c.ServerAddr,
}
}
func (s *State) ComponentFromProto(p *vagrant_server.Component) (*Component, error) {
var c Component
result := s.db.First(&c, &Component{
Name: p.Name,
ServerAddr: p.ServerAddr,
Type: component.Type(p.Type),
})
if result.Error == nil {
return &c, nil
}
if result.Error == gorm.ErrRecordNotFound {
c.Name = p.Name
c.ServerAddr = p.ServerAddr
c.Type = component.Type(p.Type)
result = s.db.Save(&c)
if result.Error != nil {
return nil, result.Error
}
return &c, nil
}
return nil, result.Error
}

View File

@ -1,40 +1,72 @@
package state package state
// TODO(spox): When dealing with the scopes on the configvar protos,
// we need to do lookups + fillins to populate parents so we index
// them correctly in memory and can properly do lookups
import ( import (
"errors"
"fmt" "fmt"
"sort" "sort"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
bolt "go.etcd.io/bbolt"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serversort "github.com/hashicorp/vagrant/internal/server/sort" serversort "github.com/hashicorp/vagrant/internal/server/sort"
"gorm.io/gorm"
) )
var configBucket = []byte("config") type Config struct {
gorm.Model
Cid *string `gorm:"uniqueIndex"`
Name string
Scope *ProtoValue // TODO(spox): polymorphic needs to allow for runner
Value string
}
func init() { func init() {
dbBuckets = append(dbBuckets, configBucket) models = append(models, &Config{})
dbIndexers = append(dbIndexers, (*State).configIndexInit) dbIndexers = append(dbIndexers, (*State).configIndexInit)
schemas = append(schemas, configIndexSchema) schemas = append(schemas, configIndexSchema)
} }
func (c *Config) ToProto() *vagrant_server.ConfigVar {
if c == nil {
return nil
}
var config vagrant_server.ConfigVar
if err := decode(c, &config); err != nil {
panic("failed to decode config: " + err.Error())
}
return &config
}
func (s *State) ConfigFromProto(p *vagrant_server.ConfigVar) (*Config, error) {
var c Config
cid := string(s.configVarId(p))
result := s.db.First(&c, &Config{Cid: &cid})
if result.Error != nil {
return nil, result.Error
}
return &c, nil
}
// ConfigSet writes a configuration variable to the data store. // ConfigSet writes a configuration variable to the data store.
func (s *State) ConfigSet(vs ...*vagrant_server.ConfigVar) error { func (s *State) ConfigSet(vs ...*vagrant_server.ConfigVar) error {
memTxn := s.inmem.Txn(true) memTxn := s.inmem.Txn(true)
defer memTxn.Abort() defer memTxn.Abort()
var err error
err := s.db.Update(func(dbTxn *bolt.Tx) error { for _, v := range vs {
for _, v := range vs { if err := s.configSet(memTxn, v); err != nil {
if err := s.configSet(dbTxn, memTxn, v); err != nil { return err
return err
}
} }
}
return nil
})
if err == nil { if err == nil {
memTxn.Commit() memTxn.Commit()
} }
@ -53,63 +85,91 @@ func (s *State) ConfigGetWatch(req *vagrant_server.ConfigGetRequest, ws memdb.Wa
memTxn := s.inmem.Txn(false) memTxn := s.inmem.Txn(false)
defer memTxn.Abort() defer memTxn.Abort()
var result []*vagrant_server.ConfigVar return s.configGetMerged(memTxn, ws, req)
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.configGetMerged(dbTxn, memTxn, ws, req)
return err
})
return result, err
} }
func (s *State) configSet( func (s *State) configSet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn, memTxn *memdb.Txn,
value *vagrant_server.ConfigVar, value *vagrant_server.ConfigVar,
) error { ) error {
id := s.configVarId(value) id := s.configVarId(value)
// Get the global bucket and write the value to it. // Persist the configuration in the db
b := dbTxn.Bucket(configBucket) c, err := s.ConfigFromProto(value)
if value.Value == "" { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err := b.Delete(id); err != nil { return err
return err }
}
} else { if err != nil {
if err := dbPut(b, id, value); err != nil { cid := string(id)
return err c = &Config{Cid: &cid}
} }
if err = s.softDecode(value, c); err != nil {
return saveErrorToStatus("config", err)
}
result := s.db.Save(c)
if result.Error != nil {
return saveErrorToStatus("config", result.Error)
} }
// Create our index value and write that. // Create our index value and write that.
return s.configIndexSet(memTxn, id, value) if err = s.configIndexSet(memTxn, id, value); err != nil {
return saveErrorToStatus("config", err)
}
return nil
} }
func (s *State) configGetMerged( func (s *State) configGetMerged(
dbTxn *bolt.Tx,
memTxn *memdb.Txn, memTxn *memdb.Txn,
ws memdb.WatchSet, ws memdb.WatchSet,
req *vagrant_server.ConfigGetRequest, req *vagrant_server.ConfigGetRequest,
) ([]*vagrant_server.ConfigVar, error) { ) ([]*vagrant_server.ConfigVar, error) {
var mergeSet [][]*vagrant_server.ConfigVar var mergeSet [][]*vagrant_server.ConfigVar
switch scope := req.Scope.(type) { switch scope := req.Scope.(type) {
case *vagrant_server.ConfigGetRequest_Basis:
// For basis scope, we just return the basis scoped values
return s.configGetExact(memTxn, ws, scope.Basis, req.Prefix)
case *vagrant_server.ConfigGetRequest_Project: case *vagrant_server.ConfigGetRequest_Project:
// For project scope, we just return the project scoped values. // For project scope, we collect project and basis values
return s.configGetExact(dbTxn, memTxn, ws, scope.Project, req.Prefix) m, err := s.configGetExact(memTxn, ws, scope.Project.Basis, req.Prefix)
if err != nil {
// TODO(spox): this should be a "something" (do we allow config for any machine,project,basis?) return nil, err
// case *vagrant_server.ConfigGetRequest_Application: }
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Project, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
case *vagrant_server.ConfigGetRequest_Target:
// For project scope, we collect project and basis values
m, err := s.configGetExact(memTxn, ws, scope.Target.Project.Basis, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Target.Project, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Target, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
case *vagrant_server.ConfigGetRequest_Runner: case *vagrant_server.ConfigGetRequest_Runner:
var err error var err error
mergeSet, err = s.configGetRunner(dbTxn, memTxn, ws, scope.Runner, req.Prefix) mergeSet, err = s.configGetRunner(memTxn, ws, scope.Runner, req.Prefix)
if err != nil { if err != nil {
return nil, err return nil, err
} }
default: default:
panic("unknown scope") return nil, fmt.Errorf("unknown scope type provided (%T)", req.Scope)
} }
// Merge our merge set // Merge our merge set
@ -132,34 +192,54 @@ func (s *State) configGetMerged(
// configGetExact returns the list of config variables for a scope // configGetExact returns the list of config variables for a scope
// exactly. By "exactly" we mean without any merging logic: if you request // exactly. By "exactly" we mean without any merging logic: if you request
// app-scoped variables, you'll get app-scoped variables. If a project-scoped // target-scoped variables, you'll get target-scoped variables. If a project-scoped
// variable matches, it will not be merged in. // variable matches, it will not be merged in.
func (s *State) configGetExact( func (s *State) configGetExact(
dbTxn *bolt.Tx,
memTxn *memdb.Txn, memTxn *memdb.Txn,
ws memdb.WatchSet, ws memdb.WatchSet,
ref interface{}, // should be one of the *vagrant_server.Ref_ values. ref interface{}, // should be one of the *vagrant_plugin_sdk.Ref_ or *vagrant_server.Ref_ values.
prefix string, prefix string,
) ([]*vagrant_server.ConfigVar, error) { ) ([]*vagrant_server.ConfigVar, error) {
// We have to get the correct iterator based on the scope. We check the // We have to get the correct iterator based on the scope. We check the
// scope and use the proper index to get the iterator here. // scope and use the proper index to get the iterator here.
var iter memdb.ResultIterator var iter memdb.ResultIterator
switch ref := ref.(type) { var err error
switch v := ref.(type) {
case *vagrant_plugin_sdk.Ref_Project: case *vagrant_plugin_sdk.Ref_Basis:
var err error
iter, err = memTxn.Get( iter, err = memTxn.Get(
configIndexTableName, configIndexTableName,
configIndexProjectIndexName+"_prefix", configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
ref.ResourceId, fmt.Sprintf("%s/%s", v.ResourceId, prefix),
prefix, )
if err != nil {
return nil, err
}
case *vagrant_plugin_sdk.Ref_Project:
iter, err = memTxn.Get(
configIndexTableName,
configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
fmt.Sprintf("%s/%s/%s", v.Basis.ResourceId, v.ResourceId, prefix),
)
if err != nil {
return nil, err
}
case *vagrant_plugin_sdk.Ref_Target:
iter, err = memTxn.Get(
configIndexTableName,
configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
fmt.Sprintf("%s/%s/%s/%s",
v.Project.Basis.ResourceId,
v.Project.ResourceId,
v.ResourceId,
prefix,
),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
default: default:
panic("unknown scope") return nil, fmt.Errorf("unknown scope type provided (%T)", ref)
} }
// Add to our watchset // Add to our watchset
@ -167,20 +247,20 @@ func (s *State) configGetExact(
// Go through the iterator and accumulate the results // Go through the iterator and accumulate the results
var result []*vagrant_server.ConfigVar var result []*vagrant_server.ConfigVar
b := dbTxn.Bucket(configBucket)
for { for {
current := iter.Next() current := iter.Next()
if current == nil { if current == nil {
break break
} }
var value vagrant_server.ConfigVar var value Config
record := current.(*configIndexRecord) record := current.(*configIndexRecord)
if err := dbGet(b, []byte(record.Id), &value); err != nil { res := s.db.First(&value, &Config{Cid: &record.Id})
return nil, err if res.Error != nil {
return nil, res.Error
} }
result = append(result, value.ToProto())
result = append(result, &value)
} }
return result, nil return result, nil
@ -188,7 +268,6 @@ func (s *State) configGetExact(
// configGetRunner gets the config vars for a runner. // configGetRunner gets the config vars for a runner.
func (s *State) configGetRunner( func (s *State) configGetRunner(
dbTxn *bolt.Tx,
memTxn *memdb.Txn, memTxn *memdb.Txn,
ws memdb.WatchSet, ws memdb.WatchSet,
req *vagrant_server.Ref_RunnerId, req *vagrant_server.Ref_RunnerId,
@ -196,7 +275,7 @@ func (s *State) configGetRunner(
) ([][]*vagrant_server.ConfigVar, error) { ) ([][]*vagrant_server.ConfigVar, error) {
iter, err := memTxn.Get( iter, err := memTxn.Get(
configIndexTableName, configIndexTableName,
configIndexRunnerIndexName+"_prefix", configIndexRunnerIndexName+"_prefix", // Enable a prefix match on lookup
true, true,
prefix, prefix,
) )
@ -214,8 +293,6 @@ func (s *State) configGetRunner(
idxId = 1 idxId = 1
) )
// Go through the iterator and accumulate the results
b := dbTxn.Bucket(configBucket)
for { for {
current := iter.Next() current := iter.Next()
if current == nil { if current == nil {
@ -240,12 +317,13 @@ func (s *State) configGetRunner(
return nil, fmt.Errorf("config has unknown target type: %T", record.RunnerRef.Target) return nil, fmt.Errorf("config has unknown target type: %T", record.RunnerRef.Target)
} }
var value vagrant_server.ConfigVar var value Config
if err := dbGet(b, []byte(record.Id), &value); err != nil { res := s.db.First(&value, &Config{Cid: &record.Id})
return nil, err if res.Error != nil {
return nil, res.Error
} }
result[idx] = append(result[idx], &value) result[idx] = append(result[idx], value.ToProto())
} }
return result, nil return result, nil
@ -253,29 +331,29 @@ func (s *State) configGetRunner(
// configIndexSet writes an index record for a single config var. // configIndexSet writes an index record for a single config var.
func (s *State) configIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.ConfigVar) error { func (s *State) configIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.ConfigVar) error {
var project, application string var basis, project, target string
var runner *vagrant_server.Ref_Runner var runner *vagrant_server.Ref_Runner
switch scope := value.Scope.(type) { switch scope := value.Scope.(type) {
//TODO(spox): Does this need to be machine? Need basis too? case *vagrant_server.ConfigVar_Basis:
//case *vagrant_server.ConfigVar_Application: basis = scope.Basis.ResourceId
case *vagrant_server.ConfigVar_Project: case *vagrant_server.ConfigVar_Project:
project = scope.Project.ResourceId project = scope.Project.ResourceId
case *vagrant_server.ConfigVar_Target:
target = scope.Target.ResourceId
case *vagrant_server.ConfigVar_Runner: case *vagrant_server.ConfigVar_Runner:
runner = scope.Runner runner = scope.Runner
default: default:
panic("unknown scope") panic("unknown scope")
} }
record := &configIndexRecord{ record := &configIndexRecord{
Id: string(id), Id: string(id),
Project: project, Basis: basis,
Application: application, Project: project,
Name: value.Name, Target: target,
Runner: runner != nil, Name: value.Name,
RunnerRef: runner, Runner: runner != nil,
RunnerRef: runner,
} }
// If we have no value, we delete from the memdb index // If we have no value, we delete from the memdb index
@ -288,33 +366,48 @@ func (s *State) configIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.
} }
// configIndexInit initializes the config index from persisted data. // configIndexInit initializes the config index from persisted data.
func (s *State) configIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error { func (s *State) configIndexInit(memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(configBucket) var cfgs []Config
return bucket.ForEach(func(k, v []byte) error { result := s.db.Find(&cfgs)
var value vagrant_server.ConfigVar if result.Error != nil {
if err := proto.Unmarshal(v, &value); err != nil { return result.Error
return err }
} for _, c := range cfgs {
if err := s.configIndexSet(memTxn, k, &value); err != nil { p := c.ToProto()
if err := s.configIndexSet(memTxn, s.configVarId(p), p); err != nil {
return err return err
} }
}
return nil return nil
})
} }
func (s *State) configVarId(v *vagrant_server.ConfigVar) []byte { func (s *State) configVarId(v *vagrant_server.ConfigVar) []byte {
switch scope := v.Scope.(type) { switch scope := v.Scope.(type) {
// TODO(spox): same as above with machine/basis/etc case *vagrant_server.ConfigVar_Basis:
//case *vagrant_server.ConfigVar_Application: return []byte(
fmt.Sprintf("%v/%v",
scope.Basis.Name,
v.Name,
),
)
case *vagrant_server.ConfigVar_Project: case *vagrant_server.ConfigVar_Project:
return []byte(fmt.Sprintf("%s/%s/%s", return []byte(
scope.Project.ResourceId, fmt.Sprintf("%v/%v/%v",
"", scope.Project.Basis.ResourceId,
v.Name, scope.Project.ResourceId,
)) v.Name,
),
)
case *vagrant_server.ConfigVar_Target:
return []byte(
fmt.Sprintf("%v/%v/%v/%v",
scope.Target.Project.Basis.ResourceId,
scope.Target.Project.ResourceId,
scope.Target.ResourceId,
v.Name,
),
)
case *vagrant_server.ConfigVar_Runner: case *vagrant_server.ConfigVar_Runner:
var t string var t string
switch scope.Runner.Target.(type) { switch scope.Runner.Target.(type) {
@ -345,7 +438,26 @@ func configIndexSchema() *memdb.TableSchema {
Unique: true, Unique: true,
Indexer: &memdb.StringFieldIndex{ Indexer: &memdb.StringFieldIndex{
Field: "Id", Field: "Id",
Lowercase: true, Lowercase: false,
},
},
configIndexBasisIndexName: {
Name: configIndexBasisIndexName,
AllowMissing: true,
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
}, },
}, },
@ -355,6 +467,11 @@ func configIndexSchema() *memdb.TableSchema {
Unique: false, Unique: false,
Indexer: &memdb.CompoundIndex{ Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{ Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{ &memdb.StringFieldIndex{
Field: "Project", Field: "Project",
Lowercase: true, Lowercase: true,
@ -368,19 +485,24 @@ func configIndexSchema() *memdb.TableSchema {
}, },
}, },
configIndexApplicationIndexName: { configIndexTargetIndexName: {
Name: configIndexApplicationIndexName, Name: configIndexTargetIndexName,
AllowMissing: true, AllowMissing: true,
Unique: false, Unique: false,
Indexer: &memdb.CompoundIndex{ Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{ Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{ &memdb.StringFieldIndex{
Field: "Project", Field: "Project",
Lowercase: true, Lowercase: true,
}, },
&memdb.StringFieldIndex{ &memdb.StringFieldIndex{
Field: "Application", Field: "Target",
Lowercase: true, Lowercase: true,
}, },
@ -414,18 +536,20 @@ func configIndexSchema() *memdb.TableSchema {
} }
const ( const (
configIndexTableName = "config-index" configIndexTableName = "config-index"
configIndexIdIndexName = "id" configIndexIdIndexName = "id"
configIndexProjectIndexName = "project" configIndexBasisIndexName = "basis"
configIndexApplicationIndexName = "application" configIndexProjectIndexName = "project"
configIndexRunnerIndexName = "runner" configIndexTargetIndexName = "target"
configIndexRunnerIndexName = "runner"
) )
type configIndexRecord struct { type configIndexRecord struct {
Id string Id string
Project string Basis string
Application string Project string
Name string Target string
Runner bool // true if this is a runner config Name string
RunnerRef *vagrant_server.Ref_Runner Runner bool // true if this is a runner config
RunnerRef *vagrant_server.Ref_Runner
} }

View File

@ -5,10 +5,8 @@ import (
"time" "time"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/stretchr/testify/require"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/stretchr/testify/require"
) )
func TestConfig(t *testing.T) { func TestConfig(t *testing.T) {
@ -17,13 +15,12 @@ func TestConfig(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create a build // Create a build
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "foo", Name: "foo",
@ -34,7 +31,7 @@ func TestConfig(t *testing.T) {
// Get it exactly // Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "foo", Prefix: "foo",
@ -47,7 +44,7 @@ func TestConfig(t *testing.T) {
// Get it via a prefix match // Get it via a prefix match
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "", Prefix: "",
@ -60,7 +57,7 @@ func TestConfig(t *testing.T) {
// non-matching prefix // non-matching prefix
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "bar", Prefix: "bar",
@ -76,13 +73,13 @@ func TestConfig(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create a build // Create a build
require.NoError(s.ConfigSet( require.NoError(s.ConfigSet(
&vagrant_server.ConfigVar{ &vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "global", Name: "global",
@ -90,9 +87,7 @@ func TestConfig(t *testing.T) {
}, },
&vagrant_server.ConfigVar{ &vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "hello", Name: "hello",
@ -104,9 +99,7 @@ func TestConfig(t *testing.T) {
// Get our merged variables // Get our merged variables
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
}) })
require.NoError(err) require.NoError(err)
@ -122,9 +115,7 @@ func TestConfig(t *testing.T) {
// Get project scoped variables. This should return everything. // Get project scoped variables. This should return everything.
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
}) })
require.NoError(err) require.NoError(err)
@ -138,12 +129,12 @@ func TestConfig(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create a var // Create a var
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "foo", Name: "foo",
@ -154,7 +145,7 @@ func TestConfig(t *testing.T) {
// Get it exactly // Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "foo", Prefix: "foo",
@ -166,9 +157,7 @@ func TestConfig(t *testing.T) {
// Delete it // Delete it
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "foo", Name: "foo",
@ -179,7 +168,7 @@ func TestConfig(t *testing.T) {
// Get it exactly // Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "foo", Prefix: "foo",
@ -195,6 +184,8 @@ func TestConfig(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create the config // Create the config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Runner{ Scope: &vagrant_server.ConfigVar_Runner{
@ -212,9 +203,7 @@ func TestConfig(t *testing.T) {
// Create a var that shouldn't match // Create a var that shouldn't match
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "bar", Name: "bar",
@ -267,6 +256,8 @@ func TestConfig(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create the config // Create the config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Runner{ Scope: &vagrant_server.ConfigVar_Runner{
@ -286,9 +277,7 @@ func TestConfig(t *testing.T) {
// Create a var that shouldn't match // Create a var that shouldn't match
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "bar", Name: "bar",
@ -380,12 +369,14 @@ func TestConfigWatch(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
// Get it with watch // Get it with watch
vs, err := s.ConfigGetWatch(&vagrant_server.ConfigGetRequest{ vs, err := s.ConfigGetWatch(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{ Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"}, Project: projRef,
}, },
Prefix: "foo", Prefix: "foo",
@ -399,9 +390,7 @@ func TestConfigWatch(t *testing.T) {
// Create a config // Create a config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{ require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{ Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ Project: projRef,
ResourceId: "foo",
},
}, },
Name: "foo", Name: "foo",

View File

@ -0,0 +1,808 @@
package state
import (
"fmt"
"reflect"
"time"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type Decoder struct {
*mapstructure.Decoder
}
func NewDecoder(config *mapstructure.DecoderConfig) (*Decoder, error) {
intD, err := mapstructure.NewDecoder(config)
if err != nil {
return nil, err
}
return &Decoder{intD}, nil
}
func (d *Decoder) SoftDecode(input interface{}) error {
v := reflect.Indirect(reflect.ValueOf(input))
t := v.Type()
if v.Kind() != reflect.Struct {
return d.Decode(input)
}
newFields := []reflect.StructField{}
for i := 0; i < t.NumField(); i++ {
structField := t.Field(i)
if !structField.IsExported() {
continue
}
indirectVal := reflect.Indirect(v.FieldByName(structField.Name))
if !indirectVal.IsValid() {
continue
}
val := indirectVal.Interface()
fieldType := structField.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
defaultZero := reflect.Zero(fieldType).Interface()
if reflect.DeepEqual(val, defaultZero) {
continue
}
newField := reflect.StructField{
Name: structField.Name,
PkgPath: structField.PkgPath,
Tag: structField.Tag,
Type: structField.Type,
}
newFields = append(newFields, newField)
}
newStruct := reflect.StructOf(newFields)
newInput := reflect.New(newStruct).Elem()
err := mapstructure.Decode(input, newInput.Addr().Interface())
if err != nil {
// This should not happen, but if it does, we need to bail
panic("failed to generate decode copy: " + err.Error())
}
// pval := v.FieldByName("Project").Interface()
// nval := newInput.FieldByName("Project").Interface()
// e := d.Decode(newInput.Interface())
// if e != nil {
// panic(e)
// }
// nval2 := newInput.FieldByName("Project").Interface()
// pval2 := v.FieldByName("Project").Interface()
// panic(fmt.Sprintf("\n\nold value: %#v\nnew value: %#v\n\n --- \nold value: %#v\nnew value: %#v\n",
// pval, nval, pval2, nval2))
return d.Decode(newInput.Interface())
}
// Creates a decoder with all our custom hooks. This decoder can
// be used for converting models to protobuf messages and converting
// protobuf messages to models.
func (s *State) decoder(output interface{}) *Decoder {
config := mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
projectToProtoRefHookFunc,
projectToProtoHookFunc,
s.projectFromProtoHookFunc,
s.projectFromProtoRefHookFunc,
basisToProtoHookFunc,
basisToProtoRefHookFunc,
s.basisFromProtoHookFunc,
s.basisFromProtoRefHookFunc,
targetToProtoHookFunc,
targetToProtoRefHookFunc,
s.targetFromProtoHookFunc,
s.targetFromProtoRefHookFunc,
vagrantfileToProtoHookFunc,
s.vagrantfileFromProtoHookFunc,
runnerToProtoHookFunc,
s.runnerFromProtoHookFunc,
protobufToProtoValueHookFunc,
protobufToProtoRawHookFunc,
boxToProtoHookFunc,
boxToProtoRefHookFunc,
s.boxFromProtoHookFunc,
s.boxFromProtoRefHookFunc,
timeToProtoHookFunc,
timeFromProtoHookFunc,
s.scopeFromProtoHookFunc,
scopeToProtoHookFunc,
protoValueToProtoHookFunc,
protoRawToProtoHookFunc,
s.componentFromProtoHookFunc,
componentToProtoHookFunc,
),
Result: output,
}
d, err := NewDecoder(&config)
if err != nil {
panic("failed to create mapstructure decoder: " + err.Error())
}
return d
}
// Decodes input into output structure using custom decoder
func (s *State) decode(input, output interface{}) error {
return s.decoder(output).Decode(input)
}
func (s *State) softDecode(input, output interface{}) error {
return s.decoder(output).SoftDecode(input)
}
// Creates a decoder with some of our custom hooks. This can be used
// for converting models to protobuf messages but cannot be used for
// converting protobuf messages to models.
func decoder(output interface{}) *Decoder {
config := mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
projectToProtoRefHookFunc,
projectToProtoHookFunc,
basisToProtoHookFunc,
basisToProtoRefHookFunc,
targetToProtoHookFunc,
targetToProtoRefHookFunc,
vagrantfileToProtoHookFunc,
runnerToProtoHookFunc,
protobufToProtoValueHookFunc,
protobufToProtoRawHookFunc,
boxToProtoHookFunc,
boxToProtoRefHookFunc,
timeToProtoHookFunc,
timeFromProtoHookFunc,
scopeToProtoHookFunc,
protoValueToProtoHookFunc,
protoRawToProtoHookFunc,
componentToProtoHookFunc,
),
Result: output,
}
d, err := NewDecoder(&config)
if err != nil {
panic("failed to create mapstructure decoder: " + err.Error())
}
return d
}
// Decodes input into output structure using custom decoder
func decode(input, output interface{}) error {
return decoder(output).Decode(input)
}
func softDecode(input, output interface{}) error {
return decoder(output).SoftDecode(input)
}
// Everything below here are converters
func projectToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Project)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Project)(nil)) {
return data, nil
}
p, ok := data.(*Project)
if !ok {
return nil, fmt.Errorf("cannot serialize project ref, wrong type (%T)", data)
}
return p.ToProtoRef(), nil
}
func projectToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Project)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Project)(nil)) {
return data, nil
}
p, ok := data.(*Project)
if !ok {
return nil, fmt.Errorf("cannot serialize project, wrong type (%T)", data)
}
return p.ToProto(), nil
}
func (s *State) projectFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Project)(nil)) ||
to != reflect.TypeOf((*Project)(nil)) {
return data, nil
}
p, ok := data.(*vagrant_server.Project)
if !ok {
return nil, fmt.Errorf("cannot deserialize project, wrong type (%T)", data)
}
return s.ProjectFromProto(p)
}
func (s *State) projectFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Project)(nil)) ||
to != reflect.TypeOf((*Project)(nil)) {
return data, nil
}
p, ok := data.(*vagrant_plugin_sdk.Ref_Project)
if !ok {
return nil, fmt.Errorf("cannot deserialize project ref, wrong type (%T)", data)
}
return s.ProjectFromProtoRef(p)
}
func basisToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Basis)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Basis)(nil)) {
return data, nil
}
b, ok := data.(*Basis)
if !ok {
return nil, fmt.Errorf("cannot serialize basis, wrong type (%T)", data)
}
return b.ToProto(), nil
}
func basisToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Basis)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Basis)(nil)) {
return data, nil
}
b, ok := data.(*Basis)
if !ok {
return nil, fmt.Errorf("cannot serialize basis ref, wrong type (%T)", data)
}
return b.ToProtoRef(), nil
}
func (s *State) basisFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Basis)(nil)) ||
to != reflect.TypeOf((*Basis)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_server.Basis)
if !ok {
return nil, fmt.Errorf("cannot deserialize basis, wrong type (%T)", data)
}
return s.BasisFromProto(b)
}
func (s *State) basisFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Basis)(nil)) ||
to != reflect.TypeOf((*Basis)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_plugin_sdk.Ref_Basis)
if !ok {
return nil, fmt.Errorf("cannot deserialize basis ref, wrong type (%T)", data)
}
return s.BasisFromProtoRef(b)
}
func targetToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Target)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Target)(nil)) {
return data, nil
}
t, ok := data.(*Target)
if !ok {
return nil, fmt.Errorf("cannot serialize target, wrong type (%T)", data)
}
return t.ToProto(), nil
}
func targetToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Target)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Target)(nil)) {
return data, nil
}
t, ok := data.(*Target)
if !ok {
return nil, fmt.Errorf("cannot serialize target ref, wrong type (%T)", data)
}
return t.ToProtoRef(), nil
}
func (s *State) targetFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Target)(nil)) ||
to != reflect.TypeOf((*Target)(nil)) {
return data, nil
}
t, ok := data.(*vagrant_server.Target)
if !ok {
return nil, fmt.Errorf("cannot deserialize target, wrong type (%T)", data)
}
return s.TargetFromProto(t)
}
func (s *State) targetFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Target)(nil)) ||
to != reflect.TypeOf((*Target)(nil)) {
return data, nil
}
t, ok := data.(*vagrant_plugin_sdk.Ref_Target)
if !ok {
return nil, fmt.Errorf("cannot deserialize target ref, wrong type (%T)", data)
}
return s.TargetFromProtoRef(t)
}
func vagrantfileToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Vagrantfile)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Vagrantfile)(nil)) {
return data, nil
}
v, ok := data.(*Vagrantfile)
if !ok {
return nil, fmt.Errorf("cannot serialize vagrantfile, wrong type (%T)", data)
}
return v.ToProto(), nil
}
func (s *State) vagrantfileFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Vagrantfile)(nil)) ||
to != reflect.TypeOf((*Vagrantfile)(nil)) {
return data, nil
}
v, ok := data.(*vagrant_server.Vagrantfile)
if !ok {
return nil, fmt.Errorf("cannot deserialize vagrantfile, wrong type (%T)", data)
}
return s.VagrantfileFromProto(v), nil
}
func runnerToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Runner)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Runner)(nil)) {
return data, nil
}
r, ok := data.(*Runner)
if !ok {
return nil, fmt.Errorf("cannot serialize runner, wrong type (%T)", data)
}
return r.ToProto(), nil
}
func (s *State) runnerFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Runner)(nil)) ||
to != reflect.TypeOf((*Runner)(nil)) {
return data, nil
}
r, ok := data.(*vagrant_server.Runner)
if !ok {
return nil, fmt.Errorf("cannot deserialize runner, wrong type (%T)", data)
}
return s.RunnerFromProto(r)
}
func protobufToProtoValueHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if to != reflect.TypeOf((*ProtoValue)(nil)) {
return data, nil
}
p, ok := data.(proto.Message)
if ok {
return &ProtoValue{Message: p}, nil
}
switch v := data.(type) {
case *vagrant_server.Job_Init:
return &ProtoValue{Message: v.Init}, nil
case *vagrant_server.Job_Command:
return &ProtoValue{Message: v.Command}, nil
case *vagrant_server.Job_Noop_:
return &ProtoValue{Message: v.Noop}, nil
}
return data, nil
}
func protobufToProtoRawHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if !from.Implements(reflect.TypeOf((*proto.Message)(nil)).Elem()) ||
to != reflect.TypeOf((*ProtoRaw)(nil)) {
return data, nil
}
p, ok := data.(proto.Message)
if !ok {
return nil, fmt.Errorf("cannot wrap into protovalue, wrong type (%T)", data)
}
return &ProtoRaw{Message: p}, nil
}
func boxToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Box)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Box)(nil)) {
return data, nil
}
b, ok := data.(*Box)
if !ok {
return nil, fmt.Errorf("cannot serialize box, wrong type (%T)", data)
}
return b.ToProto(), nil
}
func boxToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Box)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Box)(nil)) {
return data, nil
}
b, ok := data.(*Box)
if !ok {
return nil, fmt.Errorf("cannot serialize box ref, wrong type (%T)", data)
}
return b.ToProtoRef(), nil
}
func (s *State) boxFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Box)(nil)) ||
to != reflect.TypeOf((*Box)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_server.Box)
if !ok {
return nil, fmt.Errorf("cannot deserialize box, wrong type (%T)", data)
}
return s.BoxFromProto(b)
}
func (s *State) boxFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Box)(nil)) ||
to != reflect.TypeOf((*Box)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_plugin_sdk.Ref_Box)
if !ok {
return nil, fmt.Errorf("cannot deserialize box ref, wrong type (%T)", data)
}
return s.BoxFromProtoRef(b)
}
func timeToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*time.Time)(nil)) ||
to != reflect.TypeOf((*timestamppb.Timestamp)(nil)) {
return data, nil
}
t, ok := data.(*time.Time)
if !ok {
return nil, fmt.Errorf("cannot serialize time, wrong type (%T)", data)
}
return timestamppb.New(*t), nil
}
func timeFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*timestamppb.Timestamp)(nil)) ||
to != reflect.TypeOf((*time.Time)(nil)) {
return data, nil
}
t, ok := data.(*timestamppb.Timestamp)
if !ok {
return nil, fmt.Errorf("cannot deserialize time, wrong type (%T)", data)
}
at := t.AsTime()
return &at, nil
}
func protoValueToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*ProtoValue)(nil)) {
return data, nil
}
p, ok := data.(*ProtoValue)
if !ok {
return nil, fmt.Errorf("invalid proto value (%s -> %s)", from, to)
}
if p.Message == nil {
return nil, fmt.Errorf("proto value contents is nil (destination: %s)", to)
}
if reflect.ValueOf(p.Message).Type().AssignableTo(to) {
return p.Message, nil
}
switch v := p.Message.(type) {
// Start with Job oneof types
case *vagrant_server.Job_InitOp:
if reflect.TypeOf((*vagrant_server.Job_Init)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Init{Init: v}, nil
}
case *vagrant_server.Job_CommandOp:
if reflect.TypeOf((*vagrant_server.Job_Command)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Command{Command: v}, nil
}
case *vagrant_server.Job_Noop:
if reflect.TypeOf((*vagrant_server.Job_Noop_)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Noop_{Noop: v}, nil
}
}
return data, nil
}
func protoRawToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*ProtoRaw)(nil)) {
return data, nil
}
p, ok := data.(*ProtoRaw)
if !ok {
return nil, fmt.Errorf("invalid proto value (%s -> %s)", from, to)
}
if !reflect.ValueOf(p.Message).Type().AssignableTo(to) {
return data, nil
}
return p.Message, nil
}
func (s *State) scopeFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if (from != reflect.TypeOf((*vagrant_server.Job_Basis)(nil)) &&
from != reflect.TypeOf((*vagrant_server.Job_Project)(nil)) &&
from != reflect.TypeOf((*vagrant_server.Job_Target)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Basis)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Project)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Target)(nil))) ||
!to.Implements(reflect.TypeOf((*scope)(nil)).Elem()) {
return data, nil
}
var result scope
var err error
switch v := data.(type) {
case *vagrant_server.Job_Basis:
result, err = s.BasisFromProtoRef(v.Basis)
case *vagrant_server.ConfigVar_Basis:
result, err = s.BasisFromProtoRef(v.Basis)
case *vagrant_server.Job_Project:
result, err = s.ProjectFromProtoRef(v.Project)
case *vagrant_server.ConfigVar_Project:
result, err = s.ProjectFromProtoRef(v.Project)
case *vagrant_server.Job_Target:
result, err = s.TargetFromProtoRef(v.Target)
case *vagrant_server.ConfigVar_Target:
result, err = s.TargetFromProtoRef(v.Target)
default:
err = fmt.Errorf("invalid job scope type (%T)", data)
}
if err != nil {
return nil, err
}
return result, nil
}
func scopeToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if !from.Implements(reflect.TypeOf((*scope)(nil)).Elem()) {
return data, nil
}
switch v := data.(type) {
case *Basis:
if reflect.TypeOf((*vagrant_server.Job_Basis)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Basis{Basis: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Basis)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Basis{Basis: v.ToProtoRef()}, nil
}
case *Project:
if reflect.TypeOf((*vagrant_server.Job_Project)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Project{Project: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Project)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Project{Project: v.ToProtoRef()}, nil
}
case *Target:
if reflect.TypeOf((*vagrant_server.Job_Target)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Target{Target: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Target)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Target{Target: v.ToProtoRef()}, nil
}
}
return data, nil
}
func (s *State) componentFromProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Component)(nil)) ||
to != reflect.TypeOf((*Component)(nil)) {
return data, nil
}
c, ok := data.(*vagrant_server.Component)
if !ok {
return nil, fmt.Errorf("cannot deserialize component, wrong type (%T)", data)
}
return s.ComponentFromProto(c)
}
func componentToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Component)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Component)(nil)) {
return data, nil
}
c, ok := data.(*Component)
if !ok {
return nil, fmt.Errorf("cannot serialize component, wrong type (%T)", data)
}
return c.ToProto(), nil
}

View File

@ -0,0 +1,40 @@
package state
import (
"testing"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/stretchr/testify/require"
)
func TestSoftDecode(t *testing.T) {
t.Run("Decodes nothing when unset", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
tref := &vagrant_plugin_sdk.Target{}
var target Target
err := s.softDecode(tref, &target)
require.NoError(err)
require.Equal(target, Target{})
})
t.Run("Decodes project reference", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
pref := testProject(t, s)
tproto := &vagrant_server.Target{
Project: pref,
}
var target Target
err := s.softDecode(tproto, &target)
require.NoError(err)
require.NotNil(target.Project)
require.Equal(*target.Project.ResourceId, tproto.Project.ResourceId)
})
}

View File

@ -2,26 +2,32 @@ package state
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"reflect" "reflect"
"sort" "sort"
"time" "time"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/logbuffer" "github.com/hashicorp/vagrant/internal/server/logbuffer"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
) )
var ( func init() {
jobBucket = []byte("jobs") models = append(models, &InternalJob{})
dbIndexers = append(dbIndexers, (*State).jobIndexInit)
schemas = append(schemas, jobSchema)
}
var (
jobBucket = []byte("jobs")
jobWaitingTimeout = 2 * time.Minute jobWaitingTimeout = 2 * time.Minute
jobHeartbeatTimeout = 2 * time.Minute jobHeartbeatTimeout = 2 * time.Minute
) )
@ -35,10 +41,148 @@ const (
maximumJobsInMem = 10000 maximumJobsInMem = 10000
) )
func init() { type JobState uint8
dbBuckets = append(dbBuckets, jobBucket)
dbIndexers = append(dbIndexers, (*State).jobIndexInit) const (
schemas = append(schemas, jobSchema) JOB_STATE_UNKNOWN JobState = JobState(vagrant_server.Job_UNKNOWN)
JOB_STATE_QUEUED = JobState(vagrant_server.Job_QUEUED)
JOB_STATE_WAITING = JobState(vagrant_server.Job_WAITING)
JOB_STATE_RUNNING = JobState(vagrant_server.Job_RUNNING)
JOB_STATE_ERROR = JobState(vagrant_server.Job_ERROR)
JOB_STATE_SUCCESS = JobState(vagrant_server.Job_SUCCESS)
)
type InternalJob struct {
gorm.Model
AssignTime *time.Time
AckTime *time.Time
AssignedRunnerID uint `mapstructure:"-"`
AssignedRunner *Runner
CancelTime *time.Time
CompleteTime *time.Time
DataSource *ProtoValue
DataSourceOverrides MetadataSet
Error *ProtoValue
ExpireTime *time.Time
Labels MetadataSet
Jid *string `gorm:"uniqueIndex" mapstructure:"Id"`
Operation *ProtoValue `mapstructure:"Operation"`
QueueTime *time.Time
Result *ProtoValue
Scope scope `gorm:"-:all"`
ScopeID uint `mapstructure:"-"`
ScopeType string `mapstructure:"-"`
State JobState
TargetRunner *ProtoValue
}
// Job should come with an ID assigned to it, but if it doesn't for
// some reason, we assign one now.
func (i *InternalJob) BeforeCreate(tx *gorm.DB) error {
if i.Jid == nil {
id, err := server.Id()
if err != nil {
return err
}
i.Jid = &id
}
return nil
}
// If the job has a scope assigned to it, persist it.
func (i *InternalJob) BeforeSave(tx *gorm.DB) (err error) {
if i.Scope == nil {
i.ScopeID = 0
i.ScopeType = ""
return nil
}
switch v := i.Scope.(type) {
case *Basis:
i.ScopeID = v.ID
i.ScopeType = "basis"
case *Project:
i.ScopeID = v.ID
i.ScopeType = "project"
case *Target:
i.ScopeID = v.ID
i.ScopeType = "target"
default:
return fmt.Errorf("unknown scope type (%T)", i.Scope)
}
return nil
}
// If the job has a scope, load it.
func (i *InternalJob) AfterFind(tx *gorm.DB) (err error) {
if i.ScopeID == 0 {
return nil
}
switch i.ScopeType {
case "basis":
var b Basis
result := tx.Preload(clause.Associations).
First(&b, &Basis{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &b
case "project":
var p Project
result := tx.Preload(clause.Associations).
First(&p, &Project{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &p
case "target":
var t Target
result := tx.Preload(clause.Associations).
First(&t, &Target{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &t
default:
return fmt.Errorf("unknown scope type (%s)", i.ScopeType)
}
return nil
}
// Convert job to a protobuf message
func (i *InternalJob) ToProto() *vagrant_server.Job {
if i == nil {
return nil
}
var j vagrant_server.Job
err := decode(i, &j)
if err != nil {
panic("failed to decode job: " + err.Error())
}
return &j
}
func (s *State) InternalJobFromProto(job *vagrant_server.Job) (*InternalJob, error) {
if job == nil {
return nil, ErrEmptyProtoArgument
}
if job.Id == "" {
return nil, gorm.ErrRecordNotFound
}
var j InternalJob
result := s.search().First(&j, &InternalJob{Jid: &job.Id})
if result.Error != nil {
return nil, result.Error
}
return &j, nil
} }
func jobSchema() *memdb.TableSchema { func jobSchema() *memdb.TableSchema {
@ -115,9 +259,9 @@ type jobIndex struct {
// The basis/project/machine that this job is part of. This is used // The basis/project/machine that this job is part of. This is used
// to determine if the job is blocked. See job_assigned.go for more details. // to determine if the job is blocked. See job_assigned.go for more details.
Basis *vagrant_plugin_sdk.Ref_Basis Scope interface {
Project *vagrant_plugin_sdk.Ref_Project GetResourceId() string
Target *vagrant_plugin_sdk.Ref_Target }
// QueueTime is the time that the job was queued. // QueueTime is the time that the job was queued.
QueueTime time.Time QueueTime time.Time
@ -171,14 +315,16 @@ func (s *State) JobCreate(jobpb *vagrant_server.Job) error {
txn := s.inmem.Txn(true) txn := s.inmem.Txn(true)
defer txn.Abort() defer txn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error { err := s.jobCreate(txn, jobpb)
return s.jobCreate(dbTxn, txn, jobpb)
})
if err == nil { if err == nil {
txn.Commit() txn.Commit()
} }
return err if err != nil {
return lookupErrorToStatus("job", err)
}
return nil
} }
// JobList returns the list of jobs. // JobList returns the list of jobs.
@ -188,7 +334,7 @@ func (s *State) JobList() ([]*vagrant_server.Job, error) {
iter, err := memTxn.Get(jobTableName, jobIdIndexName+"_prefix", "") iter, err := memTxn.Get(jobTableName, jobIdIndexName+"_prefix", "")
if err != nil { if err != nil {
return nil, err return nil, lookupErrorToStatus("job", err)
} }
var result []*vagrant_server.Job var result []*vagrant_server.Job
@ -199,11 +345,10 @@ func (s *State) JobList() ([]*vagrant_server.Job, error) {
} }
idx := next.(*jobIndex) idx := next.(*jobIndex)
var job *vagrant_server.Job job, err := s.jobById(idx.Id)
err = s.db.View(func(dbTxn *bolt.Tx) error { if err != nil {
job, err = s.jobById(dbTxn, idx.Id) return nil, lookupErrorToStatus("job", err)
return err }
})
result = append(result, job) result = append(result, job)
} }
@ -220,7 +365,7 @@ func (s *State) JobById(id string, ws memdb.WatchSet) (*Job, error) {
watchCh, raw, err := memTxn.FirstWatch(jobTableName, jobIdIndexName, id) watchCh, raw, err := memTxn.FirstWatch(jobTableName, jobIdIndexName, id)
if err != nil { if err != nil {
return nil, err return nil, lookupErrorToStatus("job", err)
} }
ws.Add(watchCh) ws.Add(watchCh)
@ -235,20 +380,19 @@ func (s *State) JobById(id string, ws memdb.WatchSet) (*Job, error) {
if jobIdx.State == vagrant_server.Job_QUEUED { if jobIdx.State == vagrant_server.Job_QUEUED {
blocked, err = s.jobIsBlocked(memTxn, jobIdx, ws) blocked, err = s.jobIsBlocked(memTxn, jobIdx, ws)
if err != nil { if err != nil {
return nil, err return nil, lookupErrorToStatus("job", err)
} }
} }
var job *vagrant_server.Job job, err := s.jobById(jobIdx.Id)
err = s.db.View(func(dbTxn *bolt.Tx) error { if err != nil {
job, err = s.jobById(dbTxn, jobIdx.Id) return nil, lookupErrorToStatus("job", err)
return err }
})
result := jobIdx.Job(job) result := jobIdx.Job(job)
result.Blocked = blocked result.Blocked = blocked
return result, err return result, nil
} }
// JobAssignForRunner will wait for and assign a job to a specific runner. // JobAssignForRunner will wait for and assign a job to a specific runner.
@ -266,10 +410,13 @@ RETRY_ASSIGN:
defer txn.Abort() defer txn.Abort()
// Turn our runner into a runner record so we can more efficiently assign // Turn our runner into a runner record so we can more efficiently assign
runnerRec := newRunnerRecord(r) runnerRec, err := s.RunnerFromProto(r)
if err != nil {
return nil, fmt.Errorf("runner lookup failed: %w", err)
}
// candidateQuery finds candidate jobs to assign. // candidateQuery finds candidate jobs to assign.
type candidateFunc func(*memdb.Txn, memdb.WatchSet, *runnerRecord) (*jobIndex, error) type candidateFunc func(*memdb.Txn, memdb.WatchSet, *Runner) (*jobIndex, error)
candidateQuery := []candidateFunc{ candidateQuery := []candidateFunc{
s.jobCandidateById, s.jobCandidateById,
s.jobCandidateAny, s.jobCandidateAny,
@ -409,7 +556,7 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
// Get the job // Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id) raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil { if err != nil {
return nil, err return nil, lookupErrorToStatus("job", err)
} }
if raw == nil { if raw == nil {
return nil, status.Errorf(codes.NotFound, "job not found: %s", id) return nil, status.Errorf(codes.NotFound, "job not found: %s", id)
@ -443,7 +590,7 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, lookupErrorToStatus("job", err)
} }
// Cancel our timer // Cancel our timer
@ -467,13 +614,13 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
// Insert to update // Insert to update
if err := txn.Insert(jobTableName, job); err != nil { if err := txn.Insert(jobTableName, job); err != nil {
return nil, err return nil, saveErrorToStatus("job", err)
} }
// Update our assigned state if we nacked // Update our assigned state if we nacked
if !ack { if !ack {
if err := s.jobAssignedSet(txn, job, false); err != nil { if err := s.jobAssignedSet(txn, job, false); err != nil {
return nil, err return nil, saveErrorToStatus("job", err)
} }
} }
@ -491,7 +638,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Get the job // Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id) raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil { if err != nil {
return err return lookupErrorToStatus("job", err)
} }
if raw == nil { if raw == nil {
return status.Errorf(codes.NotFound, "job not found: %s", id) return status.Errorf(codes.NotFound, "job not found: %s", id)
@ -500,7 +647,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Update our assigned state // Update our assigned state
if err := s.jobAssignedSet(txn, job, false); err != nil { if err := s.jobAssignedSet(txn, job, false); err != nil {
return err return saveErrorToStatus("job", err)
} }
// If the job is not in the assigned state, then this is an error. // If the job is not in the assigned state, then this is an error.
@ -528,7 +675,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
return nil return nil
}) })
if err != nil { if err != nil {
return err return saveErrorToStatus("job", err)
} }
// End the job // End the job
@ -536,7 +683,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Insert to update // Insert to update
if err := txn.Insert(jobTableName, job); err != nil { if err := txn.Insert(jobTableName, job); err != nil {
return err return saveErrorToStatus("job", err)
} }
txn.Commit() txn.Commit()
@ -553,7 +700,7 @@ func (s *State) JobCancel(id string, force bool) error {
// Get the job // Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id) raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil { if err != nil {
return err return lookupErrorToStatus("job", err)
} }
if raw == nil { if raw == nil {
return status.Errorf(codes.NotFound, "job not found: %s", id) return status.Errorf(codes.NotFound, "job not found: %s", id)
@ -561,7 +708,7 @@ func (s *State) JobCancel(id string, force bool) error {
job := raw.(*jobIndex) job := raw.(*jobIndex)
if err := s.jobCancel(txn, job, force); err != nil { if err := s.jobCancel(txn, job, force); err != nil {
return err return saveErrorToStatus("job", err)
} }
txn.Commit() txn.Commit()
@ -717,11 +864,8 @@ func (s *State) JobExpire(id string) error {
// deregister between this returning true and queueing, the job may still // deregister between this returning true and queueing, the job may still
// sit in a queue indefinitely. // sit in a queue indefinitely.
func (s *State) JobIsAssignable(ctx context.Context, jobpb *vagrant_server.Job) (bool, error) { func (s *State) JobIsAssignable(ctx context.Context, jobpb *vagrant_server.Job) (bool, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
// If we have no runners, we cannot be assigned // If we have no runners, we cannot be assigned
empty, err := s.runnerEmpty(memTxn) empty, err := s.runnerEmpty()
if err != nil { if err != nil {
return false, err return false, err
} }
@ -730,89 +874,81 @@ func (s *State) JobIsAssignable(ctx context.Context, jobpb *vagrant_server.Job)
} }
// If we have a special targeting constraint, that has to be met // If we have a special targeting constraint, that has to be met
var iter memdb.ResultIterator tx := s.db.Model(&Runner{})
var targetCheck func(*vagrant_server.Runner) (bool, error)
switch v := jobpb.TargetRunner.Target.(type) { switch v := jobpb.TargetRunner.Target.(type) {
case *vagrant_server.Ref_Runner_Any: case *vagrant_server.Ref_Runner_Any:
// We need a special target check that disallows by ID only tx = tx.Where("by_id_only = ?", false)
targetCheck = func(r *vagrant_server.Runner) (bool, error) {
return !r.ByIdOnly, nil
}
iter, err = memTxn.LowerBound(runnerTableName, runnerIdIndexName, "")
case *vagrant_server.Ref_Runner_Id: case *vagrant_server.Ref_Runner_Id:
iter, err = memTxn.Get(runnerTableName, runnerIdIndexName, v.Id.Id) tx = tx.Where("rid = ?", v.Id.Id)
default: default:
return false, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner.Target) return false, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner.Target)
} }
if err != nil {
return false, err var c int64
result := tx.Count(&c)
if result.Error != nil {
return false, result.Error
} }
for { return c > 0, result.Error
raw := iter.Next()
if raw == nil {
// We're out of candidates and we found none.
return false, nil
}
runner := raw.(*runnerRecord)
// Check our target-specific check
if targetCheck != nil {
check, err := targetCheck(runner.Runner)
if err != nil {
return false, err
}
if !check {
continue
}
}
// This works!
return true, nil
}
} }
// jobIndexInit initializes the config index from persisted data. // jobIndexInit initializes the config index from persisted data.
func (s *State) jobIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error { func (s *State) jobIndexInit(memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(jobBucket) var jobs []InternalJob
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Job
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
idx, err := s.jobIndexSet(memTxn, k, &value) // Get all jobs which are not completed
result := s.search().
Where(&InternalJob{State: JOB_STATE_UNKNOWN}).
Or(&InternalJob{State: JOB_STATE_QUEUED}).
Or(&InternalJob{State: JOB_STATE_WAITING}).
Or(&InternalJob{State: JOB_STATE_RUNNING}).
Find(&jobs)
if result.Error != nil {
return result.Error
}
// Load all incomplete jobs into memory
for _, j := range jobs {
job := j.ToProto()
if j.Jid == nil {
continue
}
idx, err := s.jobIndexSet(memTxn, []byte(*j.Jid), job)
if err != nil { if err != nil {
return err return err
} }
// If the job was running or waiting, set it as assigned. // If the job was running or waiting, set it as assigned.
if value.State == vagrant_server.Job_RUNNING || value.State == vagrant_server.Job_WAITING { if j.State == JOB_STATE_WAITING || j.State == JOB_STATE_RUNNING {
if err := s.jobAssignedSet(memTxn, idx, true); err != nil { if err = s.jobAssignedSet(memTxn, idx, true); err != nil {
return err return err
} }
} }
}
return nil return nil
})
} }
// jobIndexSet writes an index record for a single job. // jobIndexSet writes an index record for a single job.
func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job) (*jobIndex, error) { func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job) (*jobIndex, error) {
rec := &jobIndex{ rec := &jobIndex{
Id: jobpb.Id, Id: jobpb.Id,
State: jobpb.State, State: jobpb.State,
Basis: jobpb.Basis, OpType: reflect.TypeOf(jobpb.Operation),
Project: jobpb.Project, }
Target: jobpb.Target,
OpType: reflect.TypeOf(jobpb.Operation), switch v := jobpb.Scope.(type) {
case *vagrant_server.Job_Basis:
rec.Scope = v.Basis
case *vagrant_server.Job_Project:
rec.Scope = v.Project
case *vagrant_server.Job_Target:
rec.Scope = v.Target
} }
// Target // Target
if jobpb.TargetRunner == nil { if jobpb.TargetRunner == nil || jobpb.TargetRunner.Target == nil {
return nil, fmt.Errorf("job target runner must be set") return nil, fmt.Errorf("job target runner must be set")
} }
switch v := jobpb.TargetRunner.Target.(type) { switch v := jobpb.TargetRunner.Target.(type) {
@ -823,7 +959,7 @@ func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job
rec.TargetRunnerId = v.Id.Id rec.TargetRunnerId = v.Id.Id
default: default:
return nil, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner.Target) return nil, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner)
} }
// Timestamps // Timestamps
@ -881,21 +1017,36 @@ func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job
return rec, txn.Insert(jobTableName, rec) return rec, txn.Insert(jobTableName, rec)
} }
func (s *State) jobCreate(dbTxn *bolt.Tx, memTxn *memdb.Txn, jobpb *vagrant_server.Job) error { func (s *State) jobCreate(memTxn *memdb.Txn, jobpb *vagrant_server.Job) error {
// Setup our initial job state // Setup our initial job state
var err error var err error
jobpb.State = vagrant_server.Job_QUEUED jobpb.State = vagrant_server.Job_QUEUED
jobpb.QueueTime = timestamppb.New(time.Now()) jobpb.QueueTime = timestamppb.New(time.Now())
id := []byte(jobpb.Id) // Convert the job proto into a record
job, err := s.InternalJobFromProto(jobpb)
// Insert into bolt if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err := dbPut(dbTxn.Bucket(jobBucket), id, jobpb); err != nil {
return err return err
} }
// Insert into the DB if err != nil {
_, err = s.jobIndexSet(memTxn, id, jobpb) job = &InternalJob{}
}
if err = s.softDecode(jobpb, job); err != nil {
return err
}
// Save the record into the db
result := s.db.Create(job)
if result.Error != nil {
return result.Error
}
id := []byte(*job.Jid)
// Insert into the in memory db
_, err = s.jobIndexSet(memTxn, id, job.ToProto())
s.pruneMu.Lock() s.pruneMu.Lock()
defer s.pruneMu.Unlock() defer s.pruneMu.Unlock()
@ -904,39 +1055,52 @@ func (s *State) jobCreate(dbTxn *bolt.Tx, memTxn *memdb.Txn, jobpb *vagrant_serv
return err return err
} }
func (s *State) jobById(dbTxn *bolt.Tx, id string) (*vagrant_server.Job, error) { func (s *State) jobById(sid string) (*vagrant_server.Job, error) {
var result vagrant_server.Job job, err := s.InternalJobFromProto(&vagrant_server.Job{Id: sid})
b := dbTxn.Bucket(jobBucket) if err != nil {
return &result, dbGet(b, []byte(id), &result) return nil, err
}
return job.ToProto(), nil
} }
func (s *State) jobReadAndUpdate(id string, f func(*vagrant_server.Job) error) (*vagrant_server.Job, error) { func (s *State) jobReadAndUpdate(id string, f func(*vagrant_server.Job) error) (*vagrant_server.Job, error) {
var result *vagrant_server.Job
var err error var err error
return result, s.db.Update(func(dbTxn *bolt.Tx) error {
result, err = s.jobById(dbTxn, id)
if err != nil {
return err
}
// Modify j, err := s.jobById(id)
if err := f(result); err != nil { if err != nil {
return err return nil, err
} }
// Commit if err := f(j); err != nil {
return dbPut(dbTxn.Bucket(jobBucket), []byte(id), result) return nil, err
}) }
ij, err := s.InternalJobFromProto(j)
if err != nil {
return nil, err
}
if err := s.softDecode(j, ij); err != nil {
return nil, err
}
result := s.db.Save(ij)
if result.Error != nil {
return nil, result.Error
}
return ij.ToProto(), nil
} }
// jobCandidateById returns the most promising candidate job to assign // jobCandidateById returns the most promising candidate job to assign
// that is targeting a specific runner by ID. // that is targeting a specific runner by ID.
func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runnerRecord) (*jobIndex, error) { func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *Runner) (*jobIndex, error) {
iter, err := memTxn.LowerBound( iter, err := memTxn.LowerBound(
jobTableName, jobTableName,
jobTargetIdIndexName, jobTargetIdIndexName,
vagrant_server.Job_QUEUED, vagrant_server.Job_QUEUED,
r.Id, *r.Rid,
time.Unix(0, 0), time.Unix(0, 0),
) )
if err != nil { if err != nil {
@ -950,7 +1114,7 @@ func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runner
} }
job := raw.(*jobIndex) job := raw.(*jobIndex)
if job.State != vagrant_server.Job_QUEUED || job.TargetRunnerId != r.Id { if job.State != vagrant_server.Job_QUEUED || job.TargetRunnerId != *r.Rid {
continue continue
} }
@ -968,7 +1132,7 @@ func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runner
} }
// jobCandidateAny returns the first candidate job that targets any runner. // jobCandidateAny returns the first candidate job that targets any runner.
func (s *State) jobCandidateAny(memTxn *memdb.Txn, ws memdb.WatchSet, r *runnerRecord) (*jobIndex, error) { func (s *State) jobCandidateAny(memTxn *memdb.Txn, ws memdb.WatchSet, r *Runner) (*jobIndex, error) {
iter, err := memTxn.LowerBound( iter, err := memTxn.LowerBound(
jobTableName, jobTableName,
jobQueueTimeIndexName, jobQueueTimeIndexName,
@ -1020,36 +1184,20 @@ func (s *State) jobsPruneOld(memTxn *memdb.Txn, max int) (int, error) {
} }
func (s *State) JobsDBPruneOld(max int) (int, error) { func (s *State) JobsDBPruneOld(max int) (int, error) {
cnt := dbCount(s.db, jobTableName) var jobs []InternalJob
toDelete := cnt - max result := s.db.Select("id").Order("queue_time asc").Offset(max).Find(&jobs)
var deleted int if result.Error != nil {
return 0, result.Error
}
deleted := len(jobs)
if deleted < 1 {
return deleted, nil
}
result = s.db.Unscoped().Delete(jobs)
if result.Error != nil {
return 0, result.Error
}
// Prune jobs from boltDB
s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(jobTableName))
cur := bucket.Cursor()
key, _ := cur.First()
for {
if key == nil {
break
}
// otherwise, prune this job! Once we've pruned enough jobs to get back
// to the maximum, we stop pruning.
toDelete--
err := bucket.Delete(key)
if err != nil {
return err
}
deleted++
if toDelete <= 0 {
break
}
key, _ = cur.Next()
}
return nil
})
return deleted, nil return deleted, nil
} }

View File

@ -59,9 +59,7 @@ func jobAssignedSchema() *memdb.TableSchema {
} }
type jobAssignedIndex struct { type jobAssignedIndex struct {
Basis string ResourceId string
Project string
Machine string
} }
// jobIsBlocked will return true if the given job is currently blocked because // jobIsBlocked will return true if the given job is currently blocked because
@ -80,7 +78,7 @@ func (s *State) jobIsBlocked(memTxn *memdb.Txn, idx *jobIndex, ws memdb.WatchSet
watchCh, value, err := memTxn.FirstWatch( watchCh, value, err := memTxn.FirstWatch(
jobAssignedTableName, jobAssignedTableName,
jobAssignedIdIndexName, jobAssignedIdIndexName,
s.jobAssignedIdxArgs(idx)..., idx.Scope.GetResourceId(),
) )
if err != nil { if err != nil {
return false, err return false, err
@ -100,11 +98,8 @@ func (s *State) jobAssignedSet(memTxn *memdb.Txn, idx *jobIndex, assigned bool)
return nil return nil
} }
args := s.jobAssignedIdxArgs(idx)
rec := &jobAssignedIndex{ rec := &jobAssignedIndex{
Basis: args[0].(string), ResourceId: idx.Scope.GetResourceId(),
Project: args[1].(string),
Machine: args[2].(string),
} }
if assigned { if assigned {
@ -113,16 +108,3 @@ func (s *State) jobAssignedSet(memTxn *memdb.Txn, idx *jobIndex, assigned bool)
return memTxn.Delete(jobAssignedTableName, rec) return memTxn.Delete(jobAssignedTableName, rec)
} }
func (s *State) jobAssignedIdxArgs(idx *jobIndex) []interface{} {
if idx.Target != nil {
return []interface{}{
idx.Target.Project.Basis.ResourceId, idx.Target.Project.ResourceId, idx.Target.ResourceId,
}
} else if idx.Project != nil {
return []interface{}{
idx.Project.Basis.ResourceId, idx.Project.ResourceId, "",
}
}
return []interface{}{idx.Basis.ResourceId, "", ""}
}

View File

@ -11,7 +11,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" // "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes" serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
) )
@ -23,9 +23,15 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -53,11 +59,14 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Project: &vagrant_plugin_sdk.Ref_Project{ Scope: &vagrant_server.Job_Project{
ResourceId: "project1", Project: projRef,
}, },
}))) })))
@ -90,10 +99,10 @@ func TestJobAssign(t *testing.T) {
} }
// Insert another job // Insert another job
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B", Id: "B",
Project: &vagrant_plugin_sdk.Ref_Project{ Scope: &vagrant_server.Job_Project{
ResourceId: "project2", Project: projRef,
}, },
}))) })))
@ -268,13 +277,22 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create two builds slightly apart // Create two builds slightly apart
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B", Id: "B",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get build A then B // Assign it, we should get build A then B
@ -304,9 +322,16 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
// Create a build by ID // Create a build by ID
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{ Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{ Id: &vagrant_server.Ref_RunnerId{
@ -316,11 +341,11 @@ func TestJobAssign(t *testing.T) {
}, },
}))) })))
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B", Id: "B",
}))) })))
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "C", Id: "C",
}))) })))
@ -354,9 +379,16 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build by ID // Create a build by ID
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{ Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{ Id: &vagrant_server.Ref_RunnerId{
@ -393,12 +425,17 @@ func TestJobAssign(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
r := &vagrant_server.Runner{Id: "R_A", ByIdOnly: true} r := &vagrant_server.Runner{Id: "R_A", ByIdOnly: true}
testRunner(t, s, r)
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Should block because none direct assign // Should block because none direct assign
@ -410,8 +447,11 @@ func TestJobAssign(t *testing.T) {
require.Equal(ctx.Err(), err) require.Equal(ctx.Err(), err)
// Create a target // Create a target
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B", Id: "B",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{ Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{ Id: &vagrant_server.Ref_RunnerId{
@ -436,9 +476,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -466,9 +512,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -501,9 +553,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -533,9 +591,15 @@ func TestJobComplete(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -550,7 +614,7 @@ func TestJobComplete(t *testing.T) {
// Complete it // Complete it
require.NoError(s.JobComplete(job.Id, &vagrant_server.Job_Result{ require.NoError(s.JobComplete(job.Id, &vagrant_server.Job_Result{
Run: &vagrant_server.Job_RunResult{}, Run: &vagrant_server.Job_CommandResult{},
}, nil)) }, nil))
// Verify it is changed // Verify it is changed
@ -568,9 +632,15 @@ func TestJobComplete(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -606,9 +676,14 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create a build // Create a build
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{ result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})) }))
require.NoError(err) require.NoError(err)
require.False(result) require.False(result)
@ -620,13 +695,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Register a runner testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, nil)))
// Should be assignable // Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{ result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{ Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{}, Any: &vagrant_server.Ref_RunnerAny{},
@ -643,15 +720,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Register a runner testRunner(t, s, &vagrant_server.Runner{Id: "R_A", ByIdOnly: true})
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, &vagrant_server.Runner{
ByIdOnly: true,
})))
// Should be assignable // Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{ result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{ Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{}, Any: &vagrant_server.Ref_RunnerAny{},
@ -668,13 +745,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Register a runner testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, nil)))
// Should be assignable // Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{ result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{ Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{ Id: &vagrant_server.Ref_RunnerId{
@ -693,18 +772,19 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Register a runner testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
runner := serverptypes.TestRunner(t, nil)
require.NoError(s.RunnerCreate(runner))
// Should be assignable // Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{ result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{ TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{ Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{ Id: &vagrant_server.Ref_RunnerId{
Id: runner.Id, Id: "R_A",
}, },
}, },
}, },
@ -720,10 +800,14 @@ func TestJobCancel(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Cancel it // Cancel it
@ -742,10 +826,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -770,10 +859,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -798,11 +892,16 @@ func TestJobCancel(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Operation: &vagrant_server.Job_Run{}, Scope: &vagrant_server.Job_Project{
Project: projRef,
},
Operation: &vagrant_server.Job_Command{},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -822,9 +921,12 @@ func TestJobCancel(t *testing.T) {
require.NotEmpty(job.CancelTime) require.NotEmpty(job.CancelTime)
// Create a another job // Create a another job
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B", Id: "B",
Operation: &vagrant_server.Job_Run{}, Scope: &vagrant_server.Job_Project{
Project: projRef,
},
Operation: &vagrant_server.Job_Command{},
}))) })))
ws := memdb.NewWatchSet() ws := memdb.NewWatchSet()
@ -843,10 +945,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -880,6 +987,8 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Set a short timeout // Set a short timeout
old := jobHeartbeatTimeout old := jobHeartbeatTimeout
@ -887,8 +996,11 @@ func TestJobHeartbeat(t *testing.T) {
jobHeartbeatTimeout = 5 * time.Millisecond jobHeartbeatTimeout = 5 * time.Millisecond
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -902,6 +1014,8 @@ func TestJobHeartbeat(t *testing.T) {
_, err = s.JobAck(job.Id, true) _, err = s.JobAck(job.Id, true)
require.NoError(err) require.NoError(err)
time.Sleep(1 * time.Second)
// Should time out // Should time out
require.Eventually(func() bool { require.Eventually(func() bool {
// Verify it is canceled // Verify it is canceled
@ -921,10 +1035,15 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -985,9 +1104,15 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))) })))
// Assign it, we should get this build // Assign it, we should get this build
@ -1026,7 +1151,7 @@ func TestJobHeartbeat(t *testing.T) {
}() }()
// Sleep for a bit // Sleep for a bit
time.Sleep(1 * time.Second) time.Sleep(10 * time.Millisecond)
// Verify it is running // Verify it is running
job, err = s.JobById("A", nil) job, err = s.JobById("A", nil)
@ -1036,6 +1161,10 @@ func TestJobHeartbeat(t *testing.T) {
// Stop heartbeating // Stop heartbeating
cancel() cancel()
// Pause before check. We encounter the database being
// scrubbed otherwise (TODO: fixme)
time.Sleep(1 * time.Second)
// Should time out // Should time out
require.Eventually(func() bool { require.Eventually(func() bool {
// Verify it is canceled // Verify it is canceled
@ -1045,72 +1174,77 @@ func TestJobHeartbeat(t *testing.T) {
}, 1*time.Second, 10*time.Millisecond) }, 1*time.Second, 10*time.Millisecond)
}) })
t.Run("times out if running state loaded on restart", func(t *testing.T) { // t.Run("times out if running state loaded on restart", func(t *testing.T) {
require := require.New(t) // require := require.New(t)
// Set a short timeout // // Set a short timeout
old := jobHeartbeatTimeout // old := jobHeartbeatTimeout
defer func() { jobHeartbeatTimeout = old }() // defer func() { jobHeartbeatTimeout = old }()
jobHeartbeatTimeout = 250 * time.Millisecond // jobHeartbeatTimeout = 250 * time.Millisecond
s := TestState(t) // s := TestState(t)
defer s.Close() // defer s.Close()
// projRef := testProject(t, s)
// testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build // // Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{ // require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A", // Id: "A",
}))) // Scope: &vagrant_server.Job_Project{
// Project: projRef,
// },
// })))
// Assign it, we should get this build // // Assign it, we should get this build
job, err := s.JobAssignForRunner(context.Background(), &vagrant_server.Runner{Id: "R_A"}) // job, err := s.JobAssignForRunner(context.Background(), &vagrant_server.Runner{Id: "R_A"})
require.NoError(err) // require.NoError(err)
require.NotNil(job) // require.NotNil(job)
require.Equal("A", job.Id) // require.Equal("A", job.Id)
require.Equal(vagrant_server.Job_WAITING, job.State) // require.Equal(vagrant_server.Job_WAITING, job.State)
// Ack it // // Ack it
_, err = s.JobAck(job.Id, true) // _, err = s.JobAck(job.Id, true)
require.NoError(err) // require.NoError(err)
// Start heartbeating // // Start heartbeating
ctx, cancel := context.WithCancel(context.Background()) // ctx, cancel := context.WithCancel(context.Background())
doneCh := make(chan struct{}) // doneCh := make(chan struct{})
defer func() { // defer func() {
cancel() // cancel()
<-doneCh // <-doneCh
}() // }()
go func(s *State) { // go func(s *State) {
defer close(doneCh) // defer close(doneCh)
tick := time.NewTicker(20 * time.Millisecond) // tick := time.NewTicker(20 * time.Millisecond)
defer tick.Stop() // defer tick.Stop()
for { // for {
select { // select {
case <-tick.C: // case <-tick.C:
s.JobHeartbeat(job.Id) // s.JobHeartbeat(job.Id)
case <-ctx.Done(): // case <-ctx.Done():
return // return
} // }
} // }
}(s) // }(s)
// Reinit the state as if we crashed // Reinit the state as if we crashed
s = TestStateReinit(t, s) // s = TestStateReinit(t, s)
defer s.Close() // defer s.Close()
// Verify it exists // // Verify it exists
job, err = s.JobById("A", nil) // job, err = s.JobById("A", nil)
require.NoError(err) // require.NoError(err)
require.Equal(vagrant_server.Job_RUNNING, job.Job.State) // require.Equal(vagrant_server.Job_RUNNING, job.Job.State)
// Should time out // // Should time out
require.Eventually(func() bool { // require.Eventually(func() bool {
// Verify it is canceled // // Verify it is canceled
job, err = s.JobById("A", nil) // job, err = s.JobById("A", nil)
require.NoError(err) // require.NoError(err)
return job.Job.State == vagrant_server.Job_ERROR // return job.Job.State == vagrant_server.Job_ERROR
}, 2*time.Second, 10*time.Millisecond) // }, 2*time.Second, 10*time.Millisecond)
}) // })
} }

View File

@ -0,0 +1,75 @@
package state
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// MetadataSet is a simple map with a string key type
// and string value type. It is stored within the database
// as a JSON type so it can be queried.
type MetadataSet map[string]string
// User consumable data type name
func (m MetadataSet) GormDataType() string {
return datatypes.JSON{}.GormDataType()
}
// Driver consumable data type name
func (m MetadataSet) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return datatypes.JSON{}.GormDBDataType(db, field)
}
// Unmarshals the store value back to original type
func (m MetadataSet) Scan(value interface{}) error {
v, ok := value.([]byte)
if !ok {
return fmt.Errorf("Failed to unmarshal JSON value: %v", value)
}
j := datatypes.JSON{}
err := j.UnmarshalJSON(v)
if err != nil {
return err
}
result := MetadataSet{}
err = json.Unmarshal(j, &result)
if err != nil {
return err
}
m = result
return nil
}
// Marshal the value for storage in the database
func (m MetadataSet) Value() (driver.Value, error) {
if len(m) < 1 {
return nil, nil
}
v, err := json.Marshal(m)
if err != nil {
return nil, err
}
return string(v), nil
}
// Convert the MetadataSet into a protobuf message
func (m MetadataSet) ToProto() *vagrant_plugin_sdk.Args_MetadataSet {
return &vagrant_plugin_sdk.Args_MetadataSet{
Metadata: map[string]string(m),
}
}
var (
_ sql.Scanner = (*MetadataSet)(nil)
_ driver.Valuer = (*MetadataSet)(nil)
_ schema.GormDataTypeInterface = (*ProtoValue)(nil)
_ migrator.GormDataTypeInterface = (*ProtoValue)(nil)
)

View File

@ -1,396 +1,342 @@
package state package state
import ( import (
"strings" "errors"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes" "gorm.io/gorm"
) )
var projectBucket = []byte("project")
func init() { func init() {
dbBuckets = append(dbBuckets, projectBucket) models = append(models, &Project{})
dbIndexers = append(dbIndexers, (*State).projectIndexInit)
schemas = append(schemas, projectIndexSchema)
} }
// ProjectPut creates or updates the given project. type Project struct {
func (s *State) ProjectPut(p *vagrant_server.Project) error { gorm.Model
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) (err error) { Basis *Basis
return s.projectPut(dbTxn, memTxn, p) BasisID *uint `gorm:"uniqueIndex:idx_bname;not null" mapstructure:"-"`
}) Vagrantfile *Vagrantfile `mapstructure:"-"`
VagrantfileID *uint `mapstructure:"-"`
DataSource *ProtoValue
Jobs []*InternalJob `gorm:"polymorphic:Scope;"`
Metadata MetadataSet
Name *string `gorm:"uniqueIndex:idx_bname;not null"`
Path *string `gorm:"uniqueIndex;not null"`
RemoteEnabled bool
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
Targets []*Target
}
if err == nil { func (p *Project) scope() interface{} {
memTxn.Commit() return p
}
// Set a public ID on the project before creating
func (p *Project) BeforeSave(tx *gorm.DB) error {
if p.ResourceId == nil {
if err := p.setId(); err != nil {
return err
}
} }
if err := p.Validate(tx); err != nil {
return err
}
func (s *State) ProjectFind(p *vagrant_server.Project) (*vagrant_server.Project, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Project
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.projectFind(dbTxn, memTxn, p)
return err return err
}) }
return result, err return nil
} }
// ProjectGet gets a project by reference. func (p *Project) Validate(tx *gorm.DB) error {
func (s *State) ProjectGet(ref *vagrant_plugin_sdk.Ref_Project) (*vagrant_server.Project, error) { err := validation.ValidateStruct(p,
memTxn := s.inmem.Txn(false) // validation.Field(&p.Basis, validation.Required),
defer memTxn.Abort() validation.Field(&p.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{Name: p.Name, BasisID: p.BasisID}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
validation.Field(&p.Path,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{Path: p.Path, BasisID: p.BasisID}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
validation.Field(&p.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{ResourceId: p.ResourceId}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
)
var result *vagrant_server.Project if err != nil {
err := s.db.View(func(dbTxn *bolt.Tx) (err error) {
result, err = s.projectGet(dbTxn, memTxn, ref)
return err return err
}) }
return result, err return nil
} }
// ProjectDelete deletes a project by reference. This is a complete data func (p *Project) setId() error {
// delete. This will delete all operations associated with this project id, err := server.Id()
// as well.
func (s *State) ProjectDelete(ref *vagrant_plugin_sdk.Ref_Project) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
// Now remove the project
return s.projectDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
}
return err
}
// ProjectList returns the list of projects.
func (s *State) ProjectList() ([]*vagrant_plugin_sdk.Ref_Project, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
return s.projectList(memTxn)
}
func (s *State) projectFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
p *vagrant_server.Project,
) (*vagrant_server.Project, error) {
var match *projectIndexRecord
// Start with the resource id first
if p.ResourceId != "" {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexIdIndexName,
p.ResourceId,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
// Try the name next
if p.Name != "" && match == nil {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexNameIndexName,
p.Name,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
// And finally the path
if p.Path != "" && match == nil {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexPathIndexName,
p.Path,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
if match == nil {
return nil, status.Errorf(codes.NotFound, "record not found for Project")
}
return s.projectGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Project{
ResourceId: match.Id,
})
}
func (s *State) projectPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Project,
) (err error) {
s.log.Trace("storing project", "project", value, "basis", value.Basis)
// Grab the stored project if it's available
existProject, err := s.projectFind(dbTxn, memTxn, value)
if err != nil { if err != nil {
// ensure value is nil to identify non-existence return err
existProject = nil
} }
p.ResourceId = &id
// Grab the basis associated to this project so it can be attached return nil
b, err := s.basisGet(dbTxn, memTxn, value.Basis)
if err != nil {
s.log.Error("failed to locate basis for project", "project", value,
"basis", value.Basis, "error", err)
return
}
// set a resource id if none set
if value.ResourceId == "" {
s.log.Trace("project has no resource id, assuming new project",
"project", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for project", "project", value,
"error", err)
return
}
}
s.log.Trace("storing project to db", "project", value)
id := s.projectId(value)
// Get the global bucket and write the value to it.
bkt := dbTxn.Bucket(projectBucket)
if err = dbPut(bkt, id, value); err != nil {
s.log.Error("failed to store project in db", "project", value, "error", err)
return
}
s.log.Trace("indexing project", "project", value)
// Create our index value and write that.
if err = s.projectIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index project", "project", value, "error", err)
return
}
s.log.Trace("adding project to basis", "project", value, "basis", b)
nb := &serverptypes.Basis{Basis: b}
if nb.AddProject(value) {
s.log.Trace("project added to basis, updating basis", "basis", b)
if err = s.basisPut(dbTxn, memTxn, b); err != nil {
s.log.Error("failed to update basis", "basis", b, "error", err)
return
}
} else {
s.log.Trace("project already exists in basis", "project", value, "basis", b)
}
// Check if the project basis was changed
if existProject != nil && existProject.Basis.ResourceId != b.ResourceId {
s.log.Trace("project basis has changed, updating old basis", "project", value,
"old-basis", existProject.Basis, "new-basis", value.Basis)
ob, err := s.basisGet(dbTxn, memTxn, existProject.Basis)
if err != nil {
s.log.Warn("failed to locate old basis, ignoring", "project", value, "old-basis",
existProject.Basis, "error", err)
return nil
}
bt := &serverptypes.Basis{Basis: ob}
if bt.DeleteProject(value) {
s.log.Trace("project deleted from old basis, updating basis", "project", value,
"old-basis", ob)
if err := s.basisPut(dbTxn, memTxn, ob); err != nil {
s.log.Error("failed to updated old basis for project removal", "project", value,
"old-basis", ob, "error", err)
}
}
}
return
} }
func (s *State) projectGet( // Convert project to reference protobuf message
dbTxn *bolt.Tx, func (p *Project) ToProtoRef() *vagrant_plugin_sdk.Ref_Project {
memTxn *memdb.Txn, if p == nil {
return nil
}
ref := vagrant_plugin_sdk.Ref_Project{}
err := decode(p, &ref)
if err != nil {
panic("failed to decode project to ref: " + err.Error())
}
return &ref
}
// Convert project to protobuf message
func (p *Project) ToProto() *vagrant_server.Project {
if p == nil {
return nil
}
var project vagrant_server.Project
err := decode(p, &project)
if err != nil {
panic("failed to decode project: " + err.Error())
}
// Manually include the vagrantfile since we force it to be ignored
if p.Vagrantfile != nil {
project.Configuration = p.Vagrantfile.ToProto()
}
return &project
}
// Load a Project from reference protobuf message.
func (s *State) ProjectFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Project, ref *vagrant_plugin_sdk.Ref_Project,
) (*vagrant_server.Project, error) { ) (*Project, error) {
var result vagrant_server.Project if ref == nil {
b := dbTxn.Bucket(projectBucket) return nil, ErrEmptyProtoArgument
return &result, dbGet(b, s.projectIdByRef(ref), &result) }
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var project Project
result := s.search().First(&project,
&Project{ResourceId: &ref.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &project, nil
} }
func (s *State) projectList( func (s *State) ProjectFromProtoRefFuzzy(
memTxn *memdb.Txn, ref *vagrant_plugin_sdk.Ref_Project,
) ([]*vagrant_plugin_sdk.Ref_Project, error) { ) (*Project, error) {
iter, err := memTxn.Get(projectIndexTableName, projectIndexIdIndexName+"_prefix", "") project, err := s.ProjectFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if ref.Basis == nil {
return nil, ErrMissingProtoParent
}
if ref.Name == "" && ref.Path == "" {
return nil, gorm.ErrRecordNotFound
}
project = &Project{}
query := &Project{}
if ref.Name != "" {
query.Name = &ref.Name
}
if ref.Path != "" {
query.Path = &ref.Path
}
result := s.search().
Joins("Basis", &Basis{ResourceId: &ref.Basis.ResourceId}).
Where(query).
First(project)
if result.Error != nil {
return nil, result.Error
}
return project, nil
}
// Load a Project from protobuf message.
func (s *State) ProjectFromProto(
p *vagrant_server.Project,
) (*Project, error) {
if p == nil {
return nil, ErrEmptyProtoArgument
}
project, err := s.ProjectFromProtoRef(
&vagrant_plugin_sdk.Ref_Project{
ResourceId: p.ResourceId,
},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result []*vagrant_plugin_sdk.Ref_Project return project, nil
for {
next := iter.Next()
if next == nil {
break
}
idx := next.(*projectIndexRecord)
result = append(result, &vagrant_plugin_sdk.Ref_Project{
ResourceId: idx.Id,
Name: idx.Name,
})
}
return result, nil
} }
func (s *State) projectDelete( func (s *State) ProjectFromProtoFuzzy(
dbTxn *bolt.Tx, p *vagrant_server.Project,
memTxn *memdb.Txn, ) (*Project, error) {
ref *vagrant_plugin_sdk.Ref_Project, if p == nil {
) (err error) { return nil, ErrEmptyProtoArgument
p, err := s.projectGet(dbTxn, memTxn, ref)
if err != nil {
return
} }
// Start with scrubbing all the machines project, err := s.ProjectFromProtoRefFuzzy(
for _, m := range p.Targets { &vagrant_plugin_sdk.Ref_Project{
if err = s.targetDelete(dbTxn, memTxn, m); err != nil { ResourceId: p.ResourceId,
return Basis: p.Basis,
} Name: p.Name,
} Path: p.Path,
// Grab the basis and remove the project
b, err := s.basisGet(dbTxn, memTxn, ref.Basis)
if err != nil {
return
}
bp := &serverptypes.Basis{Basis: b}
if bp.DeleteProjectRef(ref) {
err = s.basisPut(dbTxn, memTxn, b)
}
// Delete from bolt
if err := dbTxn.Bucket(projectBucket).Delete(s.projectId(p)); err != nil {
return err
}
// Delete from memdb
if err := memTxn.Delete(projectIndexTableName, s.newProjectIndexRecord(p)); err != nil {
return err
}
return
}
// projectIndexSet writes an index record for a single project.
func (s *State) projectIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Project) error {
return txn.Insert(projectIndexTableName, s.newProjectIndexRecord(value))
}
// projectIndexInit initializes the project index from persisted data.
func (s *State) projectIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(projectBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Project
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
if err := s.projectIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil
})
}
func projectIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: projectIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
projectIndexIdIndexName: {
Name: projectIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
projectIndexNameIndexName: {
Name: projectIndexNameIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
projectIndexPathIndexName: {
Name: projectIndexPathIndexName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Path",
Lowercase: false,
},
},
}, },
)
if err != nil {
return nil, err
} }
return project, nil
} }
const ( // Get a project record using a reference protobuf message
projectIndexIdIndexName = "id" func (s *State) ProjectGet(
projectIndexNameIndexName = "name" p *vagrant_plugin_sdk.Ref_Project,
projectIndexPathIndexName = "path" ) (*vagrant_server.Project, error) {
projectIndexTableName = "project-index" project, err := s.ProjectFromProtoRef(p)
if err != nil {
return nil, lookupErrorToStatus("project", err)
}
return project.ToProto(), nil
}
// Store a Project
func (s *State) ProjectPut(
p *vagrant_server.Project,
) (*vagrant_server.Project, error) {
project, err := s.ProjectFromProto(p)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("project", err)
}
// Make sure we don't have a nil value
if err != nil {
project = &Project{}
}
err = s.softDecode(p, project)
if err != nil {
return nil, saveErrorToStatus("project", err)
}
// If a configuration came over the wire, either create one to attach
// to the project or update the existing one
if p.Configuration != nil {
if project.Vagrantfile != nil {
project.Vagrantfile.UpdateFromProto(p.Configuration)
} else {
project.Vagrantfile = s.VagrantfileFromProto(p.Configuration)
}
}
result := s.db.Save(project)
if result.Error != nil {
return nil, saveErrorToStatus("project", result.Error)
}
return project.ToProto(), nil
}
// List all project records
func (s *State) ProjectList() ([]*vagrant_plugin_sdk.Ref_Project, error) {
var projects []Project
result := s.search().Find(&projects)
if result.Error != nil {
return nil, lookupErrorToStatus("projects", result.Error)
}
prefs := make([]*vagrant_plugin_sdk.Ref_Project, len(projects))
for i, prj := range projects {
prefs[i] = prj.ToProtoRef()
}
return prefs, nil
}
// Find a Project using a protobuf message
func (s *State) ProjectFind(p *vagrant_server.Project) (*vagrant_server.Project, error) {
project, err := s.ProjectFromProtoFuzzy(p)
if err != nil {
return nil, saveErrorToStatus("project", err)
}
return project.ToProto(), nil
}
// Delete a project
func (s *State) ProjectDelete(
p *vagrant_plugin_sdk.Ref_Project,
) error {
project, err := s.ProjectFromProtoRef(p)
// If the record was not found, we return with no error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
// If an unexpected error was encountered, return it
if err != nil {
return deleteErrorToStatus("project", err)
}
result := s.db.Delete(project)
if result.Error != nil {
return deleteErrorToStatus("project", err)
}
return nil
}
var (
_ scope = (*Project)(nil)
) )
type projectIndexRecord struct {
Id string
Name string
Path string
}
func (s *State) newProjectIndexRecord(p *vagrant_server.Project) *projectIndexRecord {
return &projectIndexRecord{
Id: p.ResourceId,
Name: strings.ToLower(p.Name),
Path: p.Path,
}
}
func (s *State) newProjectIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Project) *projectIndexRecord {
return &projectIndexRecord{
Id: ref.ResourceId,
Name: strings.ToLower(ref.Name),
}
}
func (s *State) projectId(p *vagrant_server.Project) []byte {
return []byte(p.ResourceId)
}
func (s *State) projectIdByRef(ref *vagrant_plugin_sdk.Ref_Project) []byte {
if ref == nil {
return []byte{}
}
return []byte(ref.ResourceId)
}

View File

@ -34,23 +34,21 @@ func TestProject(t *testing.T) {
defer s.Close() defer s.Close()
basisRef := testBasis(t, s) basisRef := testBasis(t, s)
resourceId := "AbCdE"
// Set // Set
err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{ result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
ResourceId: resourceId, Basis: basisRef,
Basis: basisRef, Path: "idontexist",
Path: "idontexist",
})) }))
require.NoError(err) require.NoError(err)
// Get exact // Get exact
{ {
resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{ resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId) require.Equal(resp.ResourceId, result.ResourceId)
} }
@ -70,16 +68,15 @@ func TestProject(t *testing.T) {
basisRef := testBasis(t, s) basisRef := testBasis(t, s)
// Set // Set
err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{ result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
ResourceId: "AbCdE", Basis: basisRef,
Basis: basisRef, Path: "idontexist",
Path: "idontexist",
})) }))
require.NoError(err) require.NoError(err)
// Read // Read
resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{ resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE", ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
@ -87,7 +84,7 @@ func TestProject(t *testing.T) {
// Delete // Delete
{ {
err := s.ProjectDelete(&vagrant_plugin_sdk.Ref_Project{ err := s.ProjectDelete(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE", ResourceId: result.ResourceId,
Basis: basisRef, Basis: basisRef,
}) })
require.NoError(err) require.NoError(err)
@ -96,7 +93,7 @@ func TestProject(t *testing.T) {
// Read // Read
{ {
_, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{ _, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE", ResourceId: result.ResourceId,
}) })
require.Error(err) require.Error(err)
require.Equal(codes.NotFound, status.Code(err)) require.Equal(codes.NotFound, status.Code(err))

View File

@ -0,0 +1,172 @@
package state
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/hashicorp/vagrant-plugin-sdk/internal-shared/dynamic"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// ProtoValue stores a protobuf message in the database
// as a JSON field so it can be queried. Note that using
// this type can result in lossy storage depending on
// types in the message
type ProtoValue struct {
Message proto.Message
}
// User consumable data type name
func (p *ProtoValue) GormDataType() string {
return datatypes.JSON{}.GormDataType()
}
// Driver consumable data type name
func (p *ProtoValue) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return datatypes.JSON{}.GormDBDataType(db, field)
}
// Unmarshals the store value back to original type
func (p *ProtoValue) Scan(value interface{}) error {
if value == nil {
return nil
}
s, ok := value.(string)
if !ok {
return fmt.Errorf("failed to unmarshal protobuf value, invalid type (%T)", value)
}
if s == "" {
return nil
}
v := []byte(s)
var m anypb.Any
err := protojson.Unmarshal(v, &m)
if err != nil {
return err
}
_, i, err := dynamic.DecodeAny(&m)
if err != nil {
return err
}
pm, ok := i.(proto.Message)
if !ok {
return fmt.Errorf("failed to set unmarshaled proto value, invalid type (%T)", i)
}
p.Message = pm
return nil
}
// Marshal the value for storage in the database
func (p *ProtoValue) Value() (driver.Value, error) {
if p == nil || p.Message == nil {
return nil, nil
}
a, err := dynamic.EncodeAny(p.Message)
if err != nil {
return nil, err
}
j, err := protojson.Marshal(a)
if err != nil {
return nil, err
}
return string(j), nil
}
// ProtoRaw stores a protobuf message in the database
// as raw bytes. Note that when using this type the
// contents of the protobuf message cannot be queried
type ProtoRaw struct {
Message proto.Message
}
// User consumable data type name
func (p *ProtoRaw) GormDataType() string {
return "bytes"
}
// Driver consumable data type name
func (p *ProtoRaw) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return "BLOB"
}
// Unmarshals the store value back to original type
func (p *ProtoRaw) Scan(value interface{}) error {
if p == nil || value == nil {
return nil
}
s, ok := value.(string)
if !ok {
return fmt.Errorf("failed to unmarshal protobuf raw, invalid type (%T)", value)
}
if s == "" {
return nil
}
v := []byte(s)
var a anypb.Any
err := proto.Unmarshal(v, &a)
if err != nil {
return err
}
_, m, err := dynamic.DecodeAny(&a)
if err != nil {
return err
}
pm, ok := m.(proto.Message)
if !ok {
return fmt.Errorf("failed to set unmarshaled proto raw, invalid type (%T)", m)
}
p.Message = pm
return nil
}
// Marshal the value for storage in the database
func (p *ProtoRaw) Value() (driver.Value, error) {
if p == nil || p.Message == nil {
return nil, nil
}
m, err := dynamic.EncodeAny(p.Message)
if err != nil {
return nil, err
}
r, err := proto.Marshal(m)
if err != nil {
return nil, err
}
return string(r), nil
}
var (
_ sql.Scanner = (*ProtoValue)(nil)
_ driver.Valuer = (*ProtoValue)(nil)
_ schema.GormDataTypeInterface = (*ProtoValue)(nil)
_ migrator.GormDataTypeInterface = (*ProtoValue)(nil)
_ sql.Scanner = (*ProtoRaw)(nil)
_ driver.Valuer = (*ProtoRaw)(nil)
_ schema.GormDataTypeInterface = (*ProtoRaw)(nil)
_ migrator.GormDataTypeInterface = (*ProtoRaw)(nil)
)

View File

@ -1,102 +1,109 @@
package state package state
import ( import (
"github.com/hashicorp/go-memdb" "errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
) )
const ( type Runner struct {
runnerTableName = "runners" gorm.Model
runnerIdIndexName = "id"
)
func init() { Rid *string `gorm:"uniqueIndex;not null" mapstructure:"Id"`
schemas = append(schemas, runnerSchema) ByIdOnly bool
Components []*Component `gorm:"many2many:runner_components"`
} }
func runnerSchema() *memdb.TableSchema { func init() {
return &memdb.TableSchema{ models = append(models, &Runner{})
Name: runnerTableName, }
Indexes: map[string]*memdb.IndexSchema{
runnerIdIndexName: { func (r *Runner) ToProto() *vagrant_server.Runner {
Name: runnerIdIndexName, if r == nil {
AllowMissing: false, return nil
Unique: true, }
Indexer: &memdb.StringFieldIndex{
Field: "Id", components := make([]*vagrant_server.Component, len(r.Components))
Lowercase: true, for i, c := range r.Components {
}, components[i] = c.ToProto()
}, }
}, return &vagrant_server.Runner{
Id: *r.Rid,
ByIdOnly: r.ByIdOnly,
Components: components,
} }
} }
type runnerRecord struct { func (s *State) RunnerFromProto(p *vagrant_server.Runner) (*Runner, error) {
// The full Runner. All other fiels are derivatives of this. if p.Id == "" {
Runner *vagrant_server.Runner return nil, gorm.ErrRecordNotFound
}
// Id of the runner var runner Runner
Id string result := s.search().First(&runner, &Runner{Rid: &p.Id})
if result.Error != nil {
return nil, result.Error
}
return &runner, nil
}
func (s *State) RunnerById(id string) (*vagrant_server.Runner, error) {
r, err := s.RunnerFromProto(&vagrant_server.Runner{Id: id})
if err != nil {
return nil, lookupErrorToStatus("runner", err)
}
return r.ToProto(), nil
} }
func (s *State) RunnerCreate(r *vagrant_server.Runner) error { func (s *State) RunnerCreate(r *vagrant_server.Runner) error {
txn := s.inmem.Txn(true) runner, err := s.RunnerFromProto(r)
defer txn.Abort() if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return lookupErrorToStatus("runner", err)
// Create our runner
if err := txn.Insert(runnerTableName, newRunnerRecord(r)); err != nil {
return status.Errorf(codes.Aborted, err.Error())
} }
txn.Commit() if err != nil {
runner = &Runner{}
}
err = s.softDecode(r, runner)
if err != nil {
return saveErrorToStatus("runner", err)
}
result := s.db.Save(runner)
if result.Error != nil {
return saveErrorToStatus("runner", result.Error)
}
return nil return nil
} }
func (s *State) RunnerDelete(id string) error { func (s *State) RunnerDelete(id string) error {
txn := s.inmem.Txn(true) runner, err := s.RunnerFromProto(&vagrant_server.Runner{Id: id})
defer txn.Abort() if err != nil {
if _, err := txn.DeleteAll(runnerTableName, runnerIdIndexName, id); err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) {
return status.Errorf(codes.Aborted, err.Error()) return err
}
return nil
}
result := s.db.Delete(runner)
if result.Error != nil {
return deleteErrorToStatus("runner", result.Error)
} }
txn.Commit()
return nil return nil
} }
func (s *State) RunnerById(id string) (*vagrant_server.Runner, error) { // Returns if there are no registered runners
txn := s.inmem.Txn(false) func (s *State) runnerEmpty() (bool, error) {
raw, err := txn.First(runnerTableName, runnerIdIndexName, id) var c int64
txn.Abort() result := s.db.Model(&Runner{}).Count(&c)
if err != nil { if result.Error != nil {
return nil, err return false, result.Error
} }
if raw == nil { return c < 1, nil
return nil, status.Errorf(codes.NotFound, "runner ID not found")
}
return raw.(*runnerRecord).Runner, nil
}
// runnerEmpty returns true if there are no runners registered.
func (s *State) runnerEmpty(memTxn *memdb.Txn) (bool, error) {
iter, err := memTxn.LowerBound(runnerTableName, runnerIdIndexName, "")
if err != nil {
return false, err
}
return iter.Next() == nil, nil
}
// newRunnerRecord creates a runnerRecord from a runner.
func newRunnerRecord(r *vagrant_server.Runner) *runnerRecord {
rec := &runnerRecord{
Runner: r,
Id: r.Id,
}
return rec
} }

View File

@ -9,32 +9,65 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func TestRunner_crud(t *testing.T) { func TestRunner(t *testing.T) {
require := require.New(t) t.Run("Basic CRUD", func(t *testing.T) {
require := require.New(t)
s := TestState(t) s := TestState(t)
defer s.Close() defer s.Close()
// Create an instance // Create an instance
rec := &vagrant_server.Runner{Id: "A"} rec := &vagrant_server.Runner{Id: "A"}
require.NoError(s.RunnerCreate(rec)) require.NoError(s.RunnerCreate(rec))
// We should be able to find it // We should be able to find it
found, err := s.RunnerById(rec.Id) found, err := s.RunnerById(rec.Id)
require.NoError(err) require.NoError(err)
require.Equal(rec, found) require.Equal(rec.Id, found.Id)
require.Empty(found.Components)
// Delete that instance // Delete that instance
require.NoError(s.RunnerDelete(rec.Id)) require.NoError(s.RunnerDelete(rec.Id))
// We should not find it // We should not find it
found, err = s.RunnerById(rec.Id) found, err = s.RunnerById(rec.Id)
require.Error(err) require.Error(err)
require.Nil(found) require.Nil(found)
require.Equal(codes.NotFound, status.Code(err)) require.Equal(codes.NotFound, status.Code(err))
// Delete again should be fine // Delete again should be fine
require.NoError(s.RunnerDelete(rec.Id)) require.NoError(s.RunnerDelete(rec.Id))
})
t.Run("CRUD with components", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
components := []*vagrant_server.Component{
{
Name: "command",
Type: vagrant_server.Component_COMMAND,
},
{
Name: "communicator",
Type: vagrant_server.Component_COMMUNICATOR,
},
}
// Create an instance
rec := &vagrant_server.Runner{
Id: "A",
Components: components,
}
require.NoError(s.RunnerCreate(rec))
found, err := s.RunnerById(rec.Id)
require.NoError(err)
require.Equal(rec.Id, found.Id)
require.Equal(rec.Components, found.Components)
})
} }
func TestRunnerById_notFound(t *testing.T) { func TestRunnerById_notFound(t *testing.T) {

View File

@ -4,137 +4,117 @@ package state
import ( import (
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"reflect"
"sync" "sync"
"time"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb" "github.com/hashicorp/go-memdb"
"github.com/oklog/ulid/v2" "google.golang.org/grpc/codes"
bolt "go.etcd.io/bbolt" "google.golang.org/grpc/status"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
// The global variables below can be set by init() functions of other
// files in this package to setup the database state for the server.
var ( var (
// schemas is used to register schemas with the state store. Other files should // schemas is used to register schemas within the state store. Other
// use the init() callback to append to this. // files should use the init() callback to append to this.
schemas []schemaFn schemas []schemaFn
// dbBuckets is the list of buckets that should be created by dbInit. // All the data persisted models defined. Other files should
// Various components should use init() funcs to append to this. // use the init() callback to append to this list.
dbBuckets [][]byte models = []interface{}{}
// dbIndexers is the list of functions to call to initialize the // dbIndexers is the list of functions to call to initialize the
// in-memory indexes from the persisted db. // in-memory indexes from the persisted db.
dbIndexers []indexFn dbIndexers []indexFn
entropy = rand.Reader entropy = rand.Reader
// Error returned when proto value passed is nil
ErrEmptyProtoArgument = errors.New("no proto value provided")
// Error returned when a proto reference does not include its parent
ErrMissingProtoParent = errors.New("proto reference does not include parent")
) )
// indexFn is the function type for initializing in-memory indexes from
// persisted data. This is usually specified as a method handle to a
// *State method.
//
// The bolt.Tx is read-only while the memdb.Txn is a write transaction.
type indexFn func(*State, *memdb.Txn) error
// State is the primary API for state mutation for the server. // State is the primary API for state mutation for the server.
type State struct { type State struct {
// Connection to our database
db *gorm.DB
// inmem is our in-memory database that stores ephemeral data in an // inmem is our in-memory database that stores ephemeral data in an
// easier-to-query way. Some of this data may be periodically persisted // easier-to-query way. Some of this data may be periodically persisted
// but most of this data is meant to be lost when the process restarts. // but most of this data is meant to be lost when the process restarts.
inmem *memdb.MemDB inmem *memdb.MemDB
// db is our persisted on-disk database. This stores the bulk of data
// and supports a transactional model for safe concurrent access.
// inmem is used alongside db to store in-memory indexing information
// for more efficient lookups into db. This index is built online at
// boot.
db *bolt.DB
// hmacKeyNotEmpty is flipped to 1 when an hmac entry is set. This is
// used to determine if we're in a bootstrap state and can create a
// bootstrap token.
hmacKeyNotEmpty uint32
// indexers is used to track whether an indexer was called. This is // indexers is used to track whether an indexer was called. This is
// initialized during New and set to nil at the end of New. // initialized during New and set to nil at the end of New.
indexers map[uintptr]struct{} indexers map[uintptr]struct{}
// Where to log to
log hclog.Logger
// indexedJobs indicates how many job records we are tracking in memory // indexedJobs indicates how many job records we are tracking in memory
indexedJobs int indexedJobs int
// Used to track prune records // Used to track prune records
pruneMu sync.Mutex pruneMu sync.Mutex
// Where to log to
log hclog.Logger
} }
// New initializes a new State store. // New initializes a new State store.
func New(log hclog.Logger, db *bolt.DB) (*State, error) { func New(log hclog.Logger, db *gorm.DB) (*State, error) {
// Restore DB if necessary log = log.Named("state")
db, err := finalizeRestore(log, db) err := db.AutoMigrate(models...)
if err != nil { if err != nil {
log.Trace("failure encountered during finalize restore", "error", err) log.Trace("failure encountered during auto migration",
"error", err,
)
return nil, err return nil, err
} }
// Create the in-memory DB. // Create the in-memory DB
inmem, err := memdb.NewMemDB(stateStoreSchema()) inmem, err := memdb.NewMemDB(stateStoreSchema())
if err != nil { if err != nil {
log.Trace("failed to setup in-memory database", "error", err) log.Trace("failed to setup in-memory database", "error", err)
return nil, fmt.Errorf("Failed setting up state store: %s", err)
}
// Initialize and validate our on-disk format.
if err := dbInit(db); err != nil {
log.Error("failed to initialize and validate on-disk format", "error", err)
return nil, err return nil, err
} }
s := &State{inmem: inmem, db: db, log: log} s := &State{
db: db,
inmem: inmem,
log: log,
}
// Initialize our set that'll track what memdb indexers we call. // Initialize the in-memory indicies
// When we're done we always clear this out since it is never used memTxn := inmem.Txn(true)
// again.
s.indexers = make(map[uintptr]struct{})
defer func() { s.indexers = nil }()
// Initialize our in-memory indexes
memTxn := s.inmem.Txn(true)
defer memTxn.Abort() defer memTxn.Abort()
err = s.db.View(func(dbTxn *bolt.Tx) error { for _, indexer := range dbIndexers {
for _, indexer := range dbIndexers { if err := indexer(s, memTxn); err != nil {
// TODO: this should use callIndexer but it's broken as it prevents the multiple op indexers return nil, err
// from properly running.
if err := indexer(s, dbTxn, memTxn); err != nil {
return err
}
} }
return nil
})
if err != nil {
log.Error("failed to generate in memory index", "error", err)
return nil, err
} }
memTxn.Commit() memTxn.Commit()
return s, nil return s, nil
} }
// callIndexer calls the specified indexer exactly once. If it has been called
// before this returns no error. This must not be called concurrently. This
// can be used from indexers to ensure other data is indexed first.
func (s *State) callIndexer(fn indexFn, dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
fnptr := reflect.ValueOf(fn).Pointer()
if _, ok := s.indexers[fnptr]; ok {
return nil
}
s.indexers[fnptr] = struct{}{}
return fn(s, dbTxn, memTxn)
}
// Close should be called to gracefully close any resources. // Close should be called to gracefully close any resources.
func (s *State) Close() error { func (s *State) Close() error {
return s.db.Close() db, err := s.db.DB()
if err != nil {
return err
}
return db.Close()
} }
// Prune should be called in a on a regular interval to allow State // Prune should be called in a on a regular interval to allow State
@ -180,17 +160,66 @@ func stateStoreSchema() *memdb.DBSchema {
return db return db
} }
// indexFn is the function type for initializing in-memory indexes from // Provides db for searching
// persisted data. This is usually specified as a method handle to a // NOTE: In most cases this should be used instead of accessing `db`
// *State method. // directly when searching for values to ensure all associations are
// // fully loaded in the results.
// The bolt.Tx is read-only while the memdb.Txn is a write transaction. func (s *State) search() *gorm.DB {
type indexFn func(*State, *bolt.Tx, *memdb.Txn) error return s.db.Preload(clause.Associations)
}
func (*State) newResourceId() (string, error) {
id, err := ulid.New(ulid.Timestamp(time.Now()), entropy) // Convert error to a GRPC status error when dealing with lookups
if err != nil { func lookupErrorToStatus(
return "", err typeName string, // thing trying to be found (basis, project, etc)
} err error, // error to convert
return id.String(), nil ) error {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorToStatus(fmt.Errorf("failed to locate %s (%w)", typeName, err))
}
if errors.Is(err, ErrEmptyProtoArgument) || errors.Is(err, ErrMissingProtoParent) {
return errorToStatus(fmt.Errorf("cannot lookup %s (%w)", typeName, err))
}
return errorToStatus(fmt.Errorf("unexpected error encountered during %s lookup (%w)", typeName, err))
}
// Convert error to GRPC status error when failing to save
func saveErrorToStatus(
typeName string, // thing trying to be saved
err error, // error to convert
) error {
var vErr validation.Error
if errors.Is(err, ErrEmptyProtoArgument) ||
errors.Is(err, ErrMissingProtoParent) ||
errors.As(err, &vErr) {
return errorToStatus(fmt.Errorf("cannot save %s (%w)", typeName, err))
}
return errorToStatus(fmt.Errorf("unexpected error encountered while saving %s (%w)", typeName, err))
}
// Convert error to GRPC status error when failing to delete
func deleteErrorToStatus(
typeName string, // thing trying to be deleted
err error, // error to convert
) error {
return errorToStatus(fmt.Errorf("unexpected error encountered while deleting %s (%w)", typeName, err))
}
// Convert error to a GRPC status error
func errorToStatus(
err error, // error to convert
) error {
if errors.Is(err, gorm.ErrRecordNotFound) {
return status.Error(codes.NotFound, err.Error())
}
var vErr validation.Error
if errors.Is(err, ErrEmptyProtoArgument) ||
errors.Is(err, ErrMissingProtoParent) ||
errors.As(err, &vErr) {
return status.Error(codes.FailedPrecondition, err.Error())
}
return status.Error(codes.Internal, err.Error())
} }

View File

@ -1,402 +1,354 @@
package state package state
import ( import (
"github.com/google/uuid" "errors"
"github.com/hashicorp/go-memdb" "fmt"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes" "gorm.io/gorm"
) )
var targetBucket = []byte("target")
func init() { func init() {
dbBuckets = append(dbBuckets, targetBucket) models = append(models, &Target{})
dbIndexers = append(dbIndexers, (*State).targetIndexInit)
schemas = append(schemas, targetIndexSchema)
} }
func (s *State) TargetFind(m *vagrant_server.Target) (*vagrant_server.Target, error) { type Target struct {
memTxn := s.inmem.Txn(false) gorm.Model
defer memTxn.Abort()
var result *vagrant_server.Target Configuration *ProtoRaw
err := s.db.View(func(dbTxn *bolt.Tx) error { Jobs []*InternalJob `gorm:"polymorphic:Scope;" mapstructure:"-"`
var err error Metadata MetadataSet
result, err = s.targetFind(dbTxn, memTxn, m) Name *string `gorm:"uniqueIndex:idx_pname;not null"`
Parent *Target `gorm:"foreignkey:ID"`
ParentID uint `mapstructure:"-"`
Project *Project
ProjectID *uint `gorm:"uniqueIndex:idx_pname;not null" mapstructure:"-"`
Provider *string
Record *ProtoRaw
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
State vagrant_server.Operation_PhysicalState
Subtargets []*Target `gorm:"foreignkey:ParentID"`
Uuid *string `gorm:"uniqueIndex"`
l hclog.Logger
}
func (t *Target) scope() interface{} {
return t
}
// Set a public ID on the target before creating
func (t *Target) BeforeSave(tx *gorm.DB) error {
if t.ResourceId == nil {
if err := t.setId(); err != nil {
return err
}
}
if err := t.validate(tx); err != nil {
return err return err
})
return result, err
}
func (s *State) TargetPut(target *vagrant_server.Target) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.targetPut(dbTxn, memTxn, target)
})
if err == nil {
memTxn.Commit()
}
return err
}
func (s *State) TargetDelete(ref *vagrant_plugin_sdk.Ref_Target) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.targetDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
} }
return err return nil
} }
func (s *State) TargetGet(ref *vagrant_plugin_sdk.Ref_Target) (*vagrant_server.Target, error) { func (t *Target) validate(tx *gorm.DB) error {
memTxn := s.inmem.Txn(false) err := validation.ValidateStruct(t,
defer memTxn.Abort() validation.Field(&t.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{Name: t.Name, ProjectID: t.ProjectID}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
validation.Field(&t.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{ResourceId: t.ResourceId}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
validation.Field(&t.Uuid,
validation.When(t.Uuid != nil,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{Uuid: t.Uuid}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
),
// validation.Field(&t.ProjectID, validation.Required), TODO(spox): why are these empty?
)
var result *vagrant_server.Target if err != nil {
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.targetGet(dbTxn, memTxn, ref)
return err return err
}) }
return result, err return nil
} }
func (s *State) TargetList() ([]*vagrant_plugin_sdk.Ref_Target, error) { func (t *Target) setId() error {
memTxn := s.inmem.Txn(false) id, err := server.Id()
defer memTxn.Abort() if err != nil {
return err
}
t.ResourceId = &id
return s.targetList(memTxn) return nil
} }
func (s *State) targetFind( // Convert target to reference protobuf message
dbTxn *bolt.Tx, func (t *Target) ToProtoRef() *vagrant_plugin_sdk.Ref_Target {
memTxn *memdb.Txn, if t == nil {
m *vagrant_server.Target, return nil
) (*vagrant_server.Target, error) {
var match *targetIndexRecord
req := s.newTargetIndexRecord(m)
// Start with the resource id first
if req.Id != "" {
if raw, err := memTxn.First(
targetIndexTableName,
targetIndexIdIndexName,
req.Id,
); raw != nil && err == nil {
match = raw.(*targetIndexRecord)
}
}
// Try the name + project next
if match == nil && req.Name != "" {
// Match the name first
raw, err := memTxn.Get(
targetIndexTableName,
targetIndexNameIndexName,
req.Name,
)
if err != nil {
return nil, err
}
// Check for matching project next
if req.ProjectId != "" {
for e := raw.Next(); e != nil; e = raw.Next() {
targetIndexEntry := e.(*targetIndexRecord)
if targetIndexEntry.ProjectId == req.ProjectId {
match = targetIndexEntry
break
}
}
} else {
e := raw.Next()
if e != nil {
match = e.(*targetIndexRecord)
}
}
}
// Finally try the uuid
if match == nil && req.Uuid != "" {
if raw, err := memTxn.First(
targetIndexTableName,
targetIndexUuidName,
req.Uuid,
); raw != nil && err == nil {
match = raw.(*targetIndexRecord)
}
} }
if match == nil { var ref vagrant_plugin_sdk.Ref_Target
return nil, status.Errorf(codes.NotFound, "record not found for Target (name: %s resource_id: %s)", m.Name, m.ResourceId)
err := decode(t, &ref)
if err != nil {
panic("failed to decode target to ref: " + err.Error())
} }
return s.targetGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Target{ return &ref
ResourceId: match.Id,
})
} }
func (s *State) targetList( // Convert target to protobuf message
memTxn *memdb.Txn, func (t *Target) ToProto() *vagrant_server.Target {
) ([]*vagrant_plugin_sdk.Ref_Target, error) { if t == nil {
iter, err := memTxn.Get(targetIndexTableName, targetIndexIdIndexName+"_prefix", "") return nil
}
var target vagrant_server.Target
err := decode(t, &target)
if err != nil {
panic("failed to decode target: " + err.Error())
}
return &target
}
// Load a Target from reference protobuf message
func (s *State) TargetFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Target,
) (*Target, error) {
if ref == nil {
return nil, ErrEmptyProtoArgument
}
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var target Target
result := s.search().Preload("Project.Basis").First(&target,
&Target{ResourceId: &ref.ResourceId},
)
if result.Error != nil {
return nil, result.Error
}
return &target, nil
}
func (s *State) TargetFromProtoRefFuzzy(
ref *vagrant_plugin_sdk.Ref_Target,
) (*Target, error) {
target, err := s.TargetFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if ref.Project == nil {
return nil, ErrMissingProtoParent
}
if ref.Name == "" {
return nil, gorm.ErrRecordNotFound
}
target = &Target{}
result := s.search().
Joins("Project", &Project{ResourceId: &ref.Project.ResourceId}).
Preload("Project.Basis").
First(target, &Target{Name: &ref.Name})
if result.Error != nil {
return nil, result.Error
}
return target, nil
}
// Load a Target from protobuf message
func (s *State) TargetFromProto(
t *vagrant_server.Target,
) (*Target, error) {
target, err := s.TargetFromProtoRef(
&vagrant_plugin_sdk.Ref_Target{
ResourceId: t.ResourceId,
},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var result []*vagrant_plugin_sdk.Ref_Target return target, nil
for {
next := iter.Next()
if next == nil {
break
}
result = append(result, &vagrant_plugin_sdk.Ref_Target{
ResourceId: next.(*targetIndexRecord).Id,
Name: next.(*targetIndexRecord).Name,
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: next.(*targetIndexRecord).ProjectId,
},
})
}
return result, nil
} }
func (s *State) targetPut( func (s *State) TargetFromProtoFuzzy(
dbTxn *bolt.Tx, t *vagrant_server.Target,
memTxn *memdb.Txn, ) (*Target, error) {
value *vagrant_server.Target, target, err := s.TargetFromProto(t)
) (err error) { if err == nil {
s.log.Trace("storing target", "target", value, "project", return target, nil
value.GetProject(), "basis", value.GetProject().GetBasis())
p, err := s.projectGet(dbTxn, memTxn, value.Project)
if err != nil {
s.log.Error("failed to locate project for target", "target", value,
"project", p, "error", err)
return
} }
if value.ResourceId == "" { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// If no resource id is provided, try to find the target based on the name and project return nil, err
foundTarget, erro := s.targetFind(dbTxn, memTxn, value)
// If an invalid return code is returned from find then an error occured
if _, ok := status.FromError(erro); !ok {
return erro
}
if foundTarget != nil {
// Make sure the config doesn't get merged - we want the config to overwrite the old config
finalConfig := proto.Clone(value.Configuration)
// Merge found target with provided target
proto.Merge(value, foundTarget)
value.ResourceId = foundTarget.ResourceId
value.Uuid = foundTarget.Uuid
value.Configuration = finalConfig.(*vagrant_plugin_sdk.Args_ConfigData)
} else {
s.log.Trace("target has no resource id and could not find matching target, assuming new target",
"target", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for target", "target", value,
"error", err)
return
}
}
if value.Uuid == "" {
s.log.Trace("target has no uuid assigned, assigning...", "target", value)
uID, err := uuid.NewUUID()
if err != nil {
return err
}
value.Uuid = uID.String()
}
} }
s.log.Trace("storing target to db", "target", value) if t.Project == nil {
id := s.targetId(value) return nil, ErrMissingProtoParent
b := dbTxn.Bucket(targetBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store target in db", "target", value, "error", err)
return
} }
s.log.Trace("indexing target", "target", value) target = &Target{}
if err = s.targetIndexSet(memTxn, id, value); err != nil { query := &Target{Name: &t.Name}
s.log.Error("failed to index target", "target", value, "error", err) tx := s.db.
return Preload("Project",
s.db.Where(
&Project{ResourceId: &t.Project.ResourceId},
),
)
if t.Name != "" {
query.Name = &t.Name
}
if t.Uuid != "" {
query.Uuid = &t.Uuid
tx = tx.Or("uuid LIKE ?", fmt.Sprintf("%%%s%%", t.Uuid))
} }
s.log.Trace("adding target to project", "target", value, "project", p) result := s.search().Joins("Project").
pp := serverptypes.Project{Project: p} Preload("Project.Basis").
if pp.AddTarget(value) { Where("Project.resource_id = ?", t.Project.ResourceId).
s.log.Trace("target added to project, updating project", "project", p) First(target, query)
if err = s.projectPut(dbTxn, memTxn, p); err != nil { if result.Error != nil {
s.log.Error("failed to update project", "project", p, "error", err) return nil, result.Error
return
}
} else {
s.log.Trace("target already exists in project", "target", value, "project", p)
} }
return return target, nil
} }
func (s *State) targetGet( // Get a target record using a reference protobuf message
dbTxn *bolt.Tx, func (s *State) TargetGet(
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Target, ref *vagrant_plugin_sdk.Ref_Target,
) (*vagrant_server.Target, error) { ) (*vagrant_server.Target, error) {
var result vagrant_server.Target t, err := s.TargetFromProtoRef(ref)
b := dbTxn.Bucket(targetBucket)
return &result, dbGet(b, s.targetIdByRef(ref), &result)
}
func (s *State) targetDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Target,
) (err error) {
p, err := s.projectGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Project{ResourceId: ref.Project.ResourceId})
if err != nil { if err != nil {
return return nil, lookupErrorToStatus("target", err)
} }
if err = dbTxn.Bucket(targetBucket).Delete(s.targetIdByRef(ref)); err != nil { return t.ToProto(), nil
return
}
if err = memTxn.Delete(targetIndexTableName, s.newTargetIndexRecordByRef(ref)); err != nil {
return
}
pp := serverptypes.Project{Project: p}
if pp.DeleteTargetRef(ref) {
if err = s.projectPut(dbTxn, memTxn, pp.Project); err != nil {
return
}
}
return
} }
func (s *State) targetIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Target) error { // List all target records
return txn.Insert(targetIndexTableName, s.newTargetIndexRecord(value)) func (s *State) TargetList() ([]*vagrant_plugin_sdk.Ref_Target, error) {
var targets []Target
result := s.search().Find(&targets)
if result.Error != nil {
return nil, lookupErrorToStatus("targets", result.Error)
}
trefs := make([]*vagrant_plugin_sdk.Ref_Target, len(targets))
for i, t := range targets {
trefs[i] = t.ToProtoRef()
}
return trefs, nil
} }
func (s *State) targetIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error { // Delete a target by reference protobuf message
bucket := dbTxn.Bucket(targetBucket) func (s *State) TargetDelete(
return bucket.ForEach(func(k, v []byte) error { t *vagrant_plugin_sdk.Ref_Target,
var value vagrant_server.Target ) error {
if err := proto.Unmarshal(v, &value); err != nil { target, err := s.TargetFromProtoRef(t)
return err if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
}
if err := s.targetIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil return nil
})
}
func targetIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: targetIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
targetIndexIdIndexName: {
Name: targetIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
targetIndexNameIndexName: {
Name: targetIndexNameIndexName,
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
targetIndexProjectIndexName: {
Name: targetIndexProjectIndexName,
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "ProjectId",
Lowercase: true,
},
},
targetIndexUuidName: {
Name: targetIndexUuidName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Uuid",
Lowercase: true,
},
},
},
} }
if err != nil {
return lookupErrorToStatus("target", err)
}
result := s.db.Delete(target)
if result.Error != nil {
return deleteErrorToStatus("target", result.Error)
}
return nil
} }
const ( // Store a Target
targetIndexIdIndexName = "id" func (s *State) TargetPut(
targetIndexNameIndexName = "name" t *vagrant_server.Target,
targetIndexProjectIndexName = "project" ) (*vagrant_server.Target, error) {
targetIndexUuidName = "uuid" target, err := s.TargetFromProto(t)
targetIndexTableName = "target-index" if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("target", err)
}
// Make sure we don't have a nil
if err != nil {
target = &Target{}
}
s.log.Info("pre-decode our target project", "project", target.Project)
err = s.softDecode(t, target)
if err != nil {
return nil, saveErrorToStatus("target", err)
}
s.log.Info("post-decode our target project", "project", target.Project)
if target.Project == nil {
panic("stop")
}
result := s.db.Save(target)
s.log.Info("after save target project status", "project", target.Project, "error", err)
if result.Error != nil {
return nil, saveErrorToStatus("target", result.Error)
}
return target.ToProto(), nil
}
// Find a Target
func (s *State) TargetFind(
t *vagrant_server.Target,
) (*vagrant_server.Target, error) {
target, err := s.TargetFromProtoFuzzy(t)
if err != nil {
return nil, lookupErrorToStatus("target", err)
}
return target.ToProto(), nil
}
var (
_ scope = (*Target)(nil)
) )
type targetIndexRecord struct {
Id string // Resource ID
Name string // Target Name
ProjectId string // Project Resource ID
Uuid string // Target UUID
}
func (s *State) newTargetIndexRecord(m *vagrant_server.Target) *targetIndexRecord {
var projectResourceId string
if m.Project != nil {
projectResourceId = m.Project.ResourceId
}
i := &targetIndexRecord{
Id: m.ResourceId,
Name: m.Name,
ProjectId: projectResourceId,
Uuid: m.Uuid,
}
return i
}
func (s *State) newTargetIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Target) *targetIndexRecord {
var projectResourceId string
if ref.Project != nil {
projectResourceId = ref.Project.ResourceId
}
return &targetIndexRecord{
Id: ref.ResourceId,
Name: ref.Name,
ProjectId: projectResourceId,
}
}
func (s *State) targetId(m *vagrant_server.Target) []byte {
return []byte(m.ResourceId)
}
func (s *State) targetIdByRef(m *vagrant_plugin_sdk.Ref_Target) []byte {
return []byte(m.ResourceId)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
) )
func TestTarget(t *testing.T) { func TestTarget(t *testing.T) {
@ -36,13 +35,11 @@ func TestTarget(t *testing.T) {
defer s.Close() defer s.Close()
projectRef := testProject(t, s) projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set // Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ result, err := s.TargetPut(&vagrant_server.Target{
ResourceId: resourceId, Project: projectRef,
Project: projectRef, Name: "test",
Name: "test", })
}))
require.NoError(err) require.NoError(err)
// Ensure there is one entry // Ensure there is one entry
@ -51,12 +48,14 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1) require.Len(resp, 1)
// Try to insert duplicate entry // Try to insert duplicate entry
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ doubleResult, err := s.TargetPut(&vagrant_server.Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
Project: projectRef, Project: projectRef,
Name: "test", Name: "test",
})) })
require.NoError(err) require.NoError(err)
require.Equal(doubleResult.ResourceId, result.ResourceId)
require.Equal(doubleResult.Project, result.Project)
// Ensure there is still one entry // Ensure there is still one entry
resp, err = s.TargetList() resp, err = s.TargetList()
@ -64,11 +63,11 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1) require.Len(resp, 1)
// Try to insert duplicate entry by just name and project // Try to insert duplicate entry by just name and project
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ _, err = s.TargetPut(&vagrant_server.Target{
Project: projectRef, Project: projectRef,
Name: "test", Name: "test",
})) })
require.NoError(err) require.Error(err)
// Ensure there is still one entry // Ensure there is still one entry
resp, err = s.TargetList() resp, err = s.TargetList()
@ -78,9 +77,8 @@ func TestTarget(t *testing.T) {
// Try to insert duplicate config // Try to insert duplicate config
key, _ := anypb.New(&wrapperspb.StringValue{Value: "vm"}) key, _ := anypb.New(&wrapperspb.StringValue{Value: "vm"})
value, _ := anypb.New(&wrapperspb.StringValue{Value: "value"}) value, _ := anypb.New(&wrapperspb.StringValue{Value: "value"})
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ _, err = s.TargetPut(&vagrant_server.Target{
Project: projectRef, ResourceId: result.ResourceId,
Name: "test",
Configuration: &vagrant_plugin_sdk.Args_ConfigData{ Configuration: &vagrant_plugin_sdk.Args_ConfigData{
Data: &vagrant_plugin_sdk.Args_Hash{ Data: &vagrant_plugin_sdk.Args_Hash{
Entries: []*vagrant_plugin_sdk.Args_HashEntry{ Entries: []*vagrant_plugin_sdk.Args_HashEntry{
@ -91,11 +89,10 @@ func TestTarget(t *testing.T) {
}, },
}, },
}, },
})) })
require.NoError(err) require.NoError(err)
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ _, err = s.TargetPut(&vagrant_server.Target{
Project: projectRef, ResourceId: result.ResourceId,
Name: "test",
Configuration: &vagrant_plugin_sdk.Args_ConfigData{ Configuration: &vagrant_plugin_sdk.Args_ConfigData{
Data: &vagrant_plugin_sdk.Args_Hash{ Data: &vagrant_plugin_sdk.Args_Hash{
Entries: []*vagrant_plugin_sdk.Args_HashEntry{ Entries: []*vagrant_plugin_sdk.Args_HashEntry{
@ -106,7 +103,7 @@ func TestTarget(t *testing.T) {
}, },
}, },
}, },
})) })
require.NoError(err) require.NoError(err)
// Ensure there is still one entry // Ensure there is still one entry
@ -115,9 +112,11 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1) require.Len(resp, 1)
// Ensure the config did not merge // Ensure the config did not merge
targetResp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{ targetResp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(targetResp.Configuration)
require.NotNil(targetResp.Configuration.Data)
require.Len(targetResp.Configuration.Data.Entries, 1) require.Len(targetResp.Configuration.Data.Entries, 1)
vmAny := targetResp.Configuration.Data.Entries[0].Value vmAny := targetResp.Configuration.Data.Entries[0].Value
vmString := wrapperspb.StringValue{} vmString := wrapperspb.StringValue{}
@ -127,11 +126,11 @@ func TestTarget(t *testing.T) {
// Get exact // Get exact
{ {
resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{ resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId) require.Equal(resp.ResourceId, result.ResourceId)
} }
@ -150,18 +149,16 @@ func TestTarget(t *testing.T) {
defer s.Close() defer s.Close()
projectRef := testProject(t, s) projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set // Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ result, err := s.TargetPut(&vagrant_server.Target{
ResourceId: resourceId, Project: projectRef,
Project: projectRef, Name: "test",
Name: "test", })
}))
require.NoError(err) require.NoError(err)
// Read // Read
resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{ resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
@ -169,7 +166,7 @@ func TestTarget(t *testing.T) {
// Delete // Delete
{ {
err := s.TargetDelete(&vagrant_plugin_sdk.Ref_Target{ err := s.TargetDelete(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
Project: projectRef, Project: projectRef,
}) })
require.NoError(err) require.NoError(err)
@ -178,7 +175,7 @@ func TestTarget(t *testing.T) {
// Read // Read
{ {
_, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{ _, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.Error(err) require.Error(err)
require.Equal(codes.NotFound, status.Code(err)) require.Equal(codes.NotFound, status.Code(err))
@ -199,33 +196,30 @@ func TestTarget(t *testing.T) {
defer s.Close() defer s.Close()
projectRef := testProject(t, s) projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set // Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{ result, err := s.TargetPut(&vagrant_server.Target{
ResourceId: resourceId, Project: projectRef,
Project: projectRef, Name: "test",
Name: "test", })
}))
require.NoError(err) require.NoError(err)
// Find by resource id // Find by resource id
{ {
resp, err := s.TargetFind(&vagrant_server.Target{ resp, err := s.TargetFind(&vagrant_server.Target{
ResourceId: resourceId, ResourceId: result.ResourceId,
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId) require.Equal(resp.ResourceId, result.ResourceId)
} }
// Find by resource name // Find by resource name without project
{ {
resp, err := s.TargetFind(&vagrant_server.Target{ resp, err := s.TargetFind(&vagrant_server.Target{
Name: "test", Name: "test",
}) })
require.NoError(err) require.Error(err)
require.NotNil(resp) require.Nil(resp)
require.Equal(resp.ResourceId, resourceId)
} }
// Find by resource name+project // Find by resource name+project
@ -235,7 +229,7 @@ func TestTarget(t *testing.T) {
}) })
require.NoError(err) require.NoError(err)
require.NotNil(resp) require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId) require.Equal(resp.ResourceId, result.ResourceId)
} }
// Don't find nonexistent project // Don't find nonexistent project
@ -243,8 +237,8 @@ func TestTarget(t *testing.T) {
resp, err := s.TargetFind(&vagrant_server.Target{ resp, err := s.TargetFind(&vagrant_server.Target{
Name: "test", Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "idontexist"}, Name: "test", Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "idontexist"},
}) })
require.Error(err)
require.Nil(resp) require.Nil(resp)
require.Error(err)
} }
// Don't find just by project // Don't find just by project

View File

@ -1,119 +1,187 @@
package state package state
import ( import (
"bytes"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"github.com/glebarez/sqlite"
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk" "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server" "github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes" "github.com/imdario/mergo"
"github.com/mitchellh/go-testing-interface" "github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt" "gorm.io/gorm"
"gorm.io/gorm/logger"
) )
// TestState returns an initialized State for testing. // TestState returns an initialized State for testing.
func TestState(t testing.T) *State { func TestState(t testing.T) *State {
result, err := New(hclog.L(), testDB(t))
require.NoError(t, err)
return result
}
// TestStateReinit reinitializes the state by pretending to restart
// the server with the database associated with this state. This can be
// used to test index init logic.
//
// This safely copies the entire DB so the old state can continue running
// with zero impact.
func TestStateReinit(t testing.T, s *State) *State {
// Copy the old database to a brand new path
td, err := ioutil.TempDir("", "test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(td) })
path := filepath.Join(td, "test.db")
// Start db copy
require.NoError(t, s.db.View(func(tx *bolt.Tx) error {
return tx.CopyFile(path, 0600)
}))
// Open the new DB
db, err := bolt.Open(path, 0600, nil)
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
// Init new state
result, err := New(hclog.L(), db)
require.NoError(t, err)
return result
}
// TestStateRestart closes the given state and restarts it against the
// same DB file. Unlike TestStateReinit, this does not copy the data and
// the old state is no longer usable.
func TestStateRestart(t testing.T, s *State) (*State, error) {
path := s.db.Path()
require.NoError(t, s.Close())
// Open the new DB
db, err := bolt.Open(path, 0600, nil)
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
// Init new state
return New(hclog.L(), db)
}
func testDB(t testing.T) *bolt.DB {
t.Helper() t.Helper()
// Temporary directory for the database var buf bytes.Buffer
td, err := ioutil.TempDir("", "test") l := hclog.New(&hclog.LoggerOptions{
require.NoError(t, err) Name: "test",
t.Cleanup(func() { os.RemoveAll(td) }) Level: hclog.Trace,
Output: &buf,
IncludeLocation: true,
})
// Create the DB t.Cleanup(func() {
db, err := bolt.Open(filepath.Join(td, "test.db"), 0600, nil) t.Log(buf.String())
})
result, err := New(l, testDB(t))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { db.Close() }) return result
}
// // TestStateReinit reinitializes the state by pretending to restart
// // the server with the database associated with this state. This can be
// // used to test index init logic.
// //
// // This safely copies the entire DB so the old state can continue running
// // with zero impact.
// func TestStateReinit(t testing.T, s *State) *State {
// // Copy the old database to a brand new path
// td, err := ioutil.TempDir("", "test")
// require.NoError(t, err)
// t.Cleanup(func() { os.RemoveAll(td) })
// path := filepath.Join(td, "test.db")
// // Start db copy
// require.NoError(t, s.db.View(func(tx *bolt.Tx) error {
// return tx.CopyFile(path, 0600)
// }))
// // Open the new DB
// db, err := bolt.Open(path, 0600, nil)
// require.NoError(t, err)
// t.Cleanup(func() { db.Close() })
// // Init new state
// result, err := New(hclog.L(), db)
// require.NoError(t, err)
// return result
// }
// // TestStateRestart closes the given state and restarts it against the
// // same DB file. Unlike TestStateReinit, this does not copy the data and
// // the old state is no longer usable.
// func TestStateRestart(t testing.T, s *State) (*State, error) {
// path := s.db.Path()
// require.NoError(t, s.Close())
// // Open the new DB
// db, err := bolt.Open(path, 0600, nil)
// require.NoError(t, err)
// t.Cleanup(func() { db.Close() })
// // Init new state
// return New(hclog.L(), db)
// }
func testDB(t testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(""), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
t.Cleanup(func() {
dbconn, err := db.DB()
if err == nil {
dbconn.Close()
}
})
return db return db
} }
// TestBasis creates the basis in the DB. // TestBasis creates the basis in the DB.
func testBasis(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Basis { func testBasis(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Basis {
t.Helper()
td := testTempDir(t) td := testTempDir(t)
s.BasisPut(serverptypes.TestBasis(t, &vagrant_server.Basis{ name := filepath.Base(td)
ResourceId: "test-basis", b := &Basis{
Path: td, Name: &name,
Name: "test-basis", Path: &td,
}))
return &vagrant_plugin_sdk.Ref_Basis{
ResourceId: "test-basis",
Path: td,
Name: "test-basis",
} }
result := s.db.Save(b)
require.NoError(t, result.Error)
return b.ToProtoRef()
} }
func testProject(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Project { func testProject(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Project {
t.Helper()
basisRef := testBasis(t, s) basisRef := testBasis(t, s)
s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{ b, err := s.BasisFromProtoRef(basisRef)
ResourceId: "test-project", require.NoError(t, err)
Basis: basisRef, td := testTempDir(t)
Path: "idontexist", name := filepath.Base(td)
Name: "test-project", p := &Project{
})) Name: &name,
return &vagrant_plugin_sdk.Ref_Project{ Path: &td,
ResourceId: "test-project", Basis: b,
Path: "idontexist",
Name: "test-project",
Basis: basisRef,
} }
result := s.db.Save(p)
require.NoError(t, result.Error)
return p.ToProtoRef()
}
func testRunner(t testing.T, s *State, src *vagrant_server.Runner) *vagrant_server.Runner {
t.Helper()
if src == nil {
src = &vagrant_server.Runner{}
}
id, err := server.Id()
require.NoError(t, err)
base := &vagrant_server.Runner{Id: id}
require.NoError(t, mergo.Merge(src, base))
var runner Runner
require.NoError(t, s.decode(src, &runner))
result := s.db.Save(&runner)
require.NoError(t, result.Error)
return runner.ToProto()
}
func testJob(t testing.T, src *vagrant_server.Job) *vagrant_server.Job {
t.Helper()
require.NoError(t, mergo.Merge(src,
&vagrant_server.Job{
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{},
},
},
DataSource: &vagrant_server.Job_DataSource{
Source: &vagrant_server.Job_DataSource_Local{
Local: &vagrant_server.Job_Local{},
},
},
Operation: &vagrant_server.Job_Noop_{
Noop: &vagrant_server.Job_Noop{},
},
},
))
return src
} }
func testTempDir(t testing.T) string { func testTempDir(t testing.T) string {
t.Helper()
dir, err := ioutil.TempDir("", "vagrant-test") dir, err := ioutil.TempDir("", "vagrant-test")
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(dir) }) t.Cleanup(func() { os.RemoveAll(dir) })

View File

@ -0,0 +1,68 @@
package state
import (
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
type VagrantfileFormat uint8
const (
JSON VagrantfileFormat = VagrantfileFormat(vagrant_server.Vagrantfile_JSON)
HCL = VagrantfileFormat(vagrant_server.Vagrantfile_HCL)
RUBY = VagrantfileFormat(vagrant_server.Vagrantfile_RUBY)
)
type Vagrantfile struct {
gorm.Model
Format VagrantfileFormat
Unfinalized *ProtoRaw
Finalized *ProtoRaw
Raw []byte
Path string
}
func init() {
models = append(models, &Vagrantfile{})
}
func (v *Vagrantfile) ToProto() *vagrant_server.Vagrantfile {
if v == nil {
return nil
}
return &vagrant_server.Vagrantfile{
Format: vagrant_server.Vagrantfile_Format(v.Format),
Raw: v.Raw,
Path: &vagrant_plugin_sdk.Args_Path{
Path: v.Path,
},
Unfinalized: v.Unfinalized.Message.(*vagrant_plugin_sdk.Args_Hash),
Finalized: v.Finalized.Message.(*vagrant_plugin_sdk.Args_Hash),
}
}
func (v *Vagrantfile) UpdateFromProto(vf *vagrant_server.Vagrantfile) *Vagrantfile {
v.Format = VagrantfileFormat(vf.Format)
v.Unfinalized = &ProtoRaw{Message: vf.Unfinalized}
v.Finalized = &ProtoRaw{Message: vf.Finalized}
v.Raw = vf.Raw
v.Path = vf.Path.Path
return v
}
func (s *State) VagrantfileFromProto(v *vagrant_server.Vagrantfile) *Vagrantfile {
file := &Vagrantfile{
Format: VagrantfileFormat(v.Format),
Unfinalized: &ProtoRaw{Message: v.Unfinalized},
Finalized: &ProtoRaw{Message: v.Finalized},
Raw: v.Raw,
}
if v.Path != nil {
file.Path = v.Path.Path
}
return file
}

View File

@ -0,0 +1,31 @@
package state
import (
"github.com/go-ozzo/ozzo-validation/v4"
"gorm.io/gorm"
)
type ValidationCode string
const (
VALIDATION_UNIQUE ValidationCode = "unique"
)
func checkUnique(tx *gorm.DB) validation.RuleFunc {
return func(value interface{}) error {
var count int64
result := tx.Count(&count)
if result.Error != nil {
return validation.NewInternalError(result.Error)
}
if count > 0 {
return validation.NewError(
string(VALIDATION_UNIQUE),
"must be unique",
)
}
return nil
}
}