Migrate data layer to gorm

This commit is contained in:
Chris Roberts 2022-09-19 08:40:11 -07:00
parent 0a0333adb7
commit f24ab4d855
24 changed files with 3870 additions and 2090 deletions

View File

@ -1,326 +1,340 @@
package state
import (
"strings"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"errors"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
var basisBucket = []byte("basis")
func init() {
dbBuckets = append(dbBuckets, basisBucket)
dbIndexers = append(dbIndexers, (*State).basisIndexInit)
schemas = append(schemas, basisIndexSchema)
models = append(models, &Basis{})
}
func (s *State) BasisFind(b *vagrant_server.Basis) (*vagrant_server.Basis, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
// This interface is utilized internally as an
// identifier for scopes to allow for easier mapping
type scope interface {
scope() interface{}
}
var result *vagrant_server.Basis
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.basisFind(dbTxn, memTxn, b)
type Basis struct {
gorm.Model
Vagrantfile *Vagrantfile `mapstructure:"-"`
VagrantfileID uint `mapstructure:"-"`
DataSource *ProtoValue
Jobs []*InternalJob `gorm:"polymorphic:Scope;" mapstructure:"-"`
Metadata MetadataSet
Name *string `gorm:"uniqueIndex,not null"`
Path *string `gorm:"uniqueIndex,not null"`
Projects []*Project
RemoteEnabled bool
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
}
func (b *Basis) scope() interface{} {
return b
}
// Define custom table name
func (Basis) TableName() string {
return "basis"
}
func (b *Basis) BeforeSave(tx *gorm.DB) error {
if b.ResourceId == nil {
if err := b.setId(); err != nil {
return err
})
return result, err
}
func (s *State) BasisGet(ref *vagrant_plugin_sdk.Ref_Basis) (*vagrant_server.Basis, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Basis
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.basisGet(dbTxn, memTxn, ref)
}
}
if err := b.Validate(tx); err != nil {
return err
})
return result, err
}
func (s *State) BasisPut(b *vagrant_server.Basis) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.basisPut(dbTxn, memTxn, b)
})
if err == nil {
memTxn.Commit()
}
return err
return nil
}
func (s *State) BasisDelete(ref *vagrant_plugin_sdk.Ref_Basis) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
func (b *Basis) Validate(tx *gorm.DB) error {
err := validation.ValidateStruct(b,
validation.Field(&b.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{Name: b.Name}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.Path,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{Path: b.Path}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Basis{}).
Where(&Basis{ResourceId: b.ResourceId}).
Not(&Basis{Model: gorm.Model{ID: b.ID}}),
),
),
),
)
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.basisDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
if err != nil {
return err
}
return nil
}
func (b *Basis) setId() error {
id, err := server.Id()
if err != nil {
return err
}
b.ResourceId = &id
return nil
}
func (s *State) BasisList() ([]*vagrant_plugin_sdk.Ref_Basis, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
// Convert basis to protobuf message
func (b *Basis) ToProto() *vagrant_server.Basis {
if b == nil {
return nil
}
return s.basisList(memTxn)
basis := vagrant_server.Basis{}
err := decode(b, &basis)
if err != nil {
panic("failed to decode basis: " + err.Error())
}
if b.Vagrantfile != nil {
basis.Configuration = b.Vagrantfile.ToProto()
}
return &basis
}
func (s *State) basisGet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Basis,
) (*vagrant_server.Basis, error) {
var result vagrant_server.Basis
b := dbTxn.Bucket(basisBucket)
return &result, dbGet(b, s.basisIdByRef(ref), &result)
// Convert basis to reference protobuf message
func (b *Basis) ToProtoRef() *vagrant_plugin_sdk.Ref_Basis {
if b == nil {
return nil
}
ref := vagrant_plugin_sdk.Ref_Basis{}
err := decode(b, &ref)
if err != nil {
panic("failed to decode basis to ref: " + err.Error())
}
return &ref
}
func (s *State) basisFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
// Load a Basis from a protobuf message. This will only search
// against the resource id.
func (s *State) BasisFromProto(
b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) {
var match *basisIndexRecord
// Start with the resource id first
if b.ResourceId != "" {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexIdIndexName,
b.ResourceId,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
}
// Try the name next
if b.Name != "" && match == nil {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexNameIndexName,
b.Name,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
}
// And finally the path
if b.Path != "" && match == nil {
if raw, err := memTxn.First(
basisIndexTableName,
basisIndexPathIndexName,
b.Path,
); raw != nil && err == nil {
match = raw.(*basisIndexRecord)
}
) (*Basis, error) {
if b == nil {
return nil, ErrEmptyProtoArgument
}
if match == nil {
return nil, status.Errorf(codes.NotFound, "record not found for Basis")
}
return s.basisGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Basis{
ResourceId: match.Id,
})
}
func (s *State) basisPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Basis,
) (err error) {
s.log.Trace("storing basis", "basis", value)
if value.ResourceId == "" {
s.log.Trace("basis has no resource id, assuming new basis",
"basis", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for basis", "basis", value,
"error", err)
return
}
}
s.log.Trace("storing basis to db", "basis", value)
id := s.basisId(value)
b := dbTxn.Bucket(basisBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store basis in db", "basis", value, "error", err)
return
}
s.log.Trace("indexing basis", "basis", value)
if err = s.basisIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index basis", "basis", value, "error", err)
return
}
return
}
func (s *State) basisList(
memTxn *memdb.Txn,
) ([]*vagrant_plugin_sdk.Ref_Basis, error) {
iter, err := memTxn.Get(basisIndexTableName, basisIndexIdIndexName+"_prefix", "")
basis, err := s.BasisFromProtoRef(
&vagrant_plugin_sdk.Ref_Basis{
ResourceId: b.ResourceId,
},
)
if err != nil {
return nil, err
}
var result []*vagrant_plugin_sdk.Ref_Basis
for {
next := iter.Next()
if next == nil {
break
}
idx := next.(*basisIndexRecord)
result = append(result, &vagrant_plugin_sdk.Ref_Basis{
ResourceId: idx.Id,
Name: idx.Name,
})
}
return result, nil
return basis, nil
}
func (s *State) basisDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Basis,
) error {
b, err := s.basisGet(dbTxn, memTxn, ref)
if err != nil {
if status.Code(err) == codes.NotFound {
return nil
}
return err
// Load a Basis from a protobuf message. This will attempt to locate the
// basis using any unique field it can match.
func (s *State) BasisFromProtoFuzzy(
b *vagrant_server.Basis,
) (*Basis, error) {
if b == nil {
return nil, ErrEmptyProtoArgument
}
for _, p := range b.Projects {
if err := s.projectDelete(dbTxn, memTxn, p); err != nil {
return err
}
}
// Delete from bolt
if err := dbTxn.Bucket(basisBucket).Delete(s.basisId(b)); err != nil {
return err
}
// Delete from memdb
record := s.newBasisIndexRecord(b)
if err := memTxn.Delete(basisIndexTableName, record); err != nil {
return err
}
return nil
}
func (s *State) basisIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Basis) error {
return txn.Insert(basisIndexTableName, s.newBasisIndexRecord(value))
}
func (s *State) basisIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(basisBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Basis
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
if err := s.basisIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil
})
}
func basisIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: basisIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
basisIndexIdIndexName: {
Name: basisIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
basisIndexNameIndexName: {
Name: basisIndexNameIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
basisIndexPathIndexName: {
Name: basisIndexPathIndexName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Path",
Lowercase: false,
},
},
},
}
}
const (
basisIndexIdIndexName = "id"
basisIndexNameIndexName = "name"
basisIndexPathIndexName = "path"
basisIndexTableName = "basis-index"
)
type basisIndexRecord struct {
Id string
Name string
Path string
}
func (s *State) newBasisIndexRecord(b *vagrant_server.Basis) *basisIndexRecord {
return &basisIndexRecord{
Id: b.ResourceId,
Name: strings.ToLower(b.Name),
basis, err := s.BasisFromProtoRefFuzzy(
&vagrant_plugin_sdk.Ref_Basis{
ResourceId: b.ResourceId,
Name: b.Name,
Path: b.Path,
},
)
if err != nil {
return nil, err
}
return basis, nil
}
func (s *State) newBasisIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Basis) *basisIndexRecord {
return &basisIndexRecord{
Id: ref.ResourceId,
Name: strings.ToLower(ref.Name),
}
}
func (s *State) basisId(b *vagrant_server.Basis) []byte {
return []byte(b.ResourceId)
}
func (s *State) basisIdByRef(ref *vagrant_plugin_sdk.Ref_Basis) []byte {
// Load a Basis from a reference protobuf message
func (s *State) BasisFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*Basis, error) {
if ref == nil {
return []byte{}
return nil, ErrEmptyProtoArgument
}
return []byte(ref.ResourceId)
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var basis Basis
result := s.search().First(&basis, &Basis{ResourceId: &ref.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &basis, nil
}
func (s *State) BasisFromProtoRefFuzzy(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*Basis, error) {
basis, err := s.BasisFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if basis != nil {
return basis, nil
}
// If name and path are both empty, we can't search
if ref.Name == "" && ref.Path == "" {
return nil, gorm.ErrRecordNotFound
}
basis = &Basis{}
query := &Basis{}
if ref.Name != "" {
query.Name = &ref.Name
}
if ref.Path != "" {
query.Path = &ref.Path
}
result := s.search().First(basis, query)
if result.Error != nil {
return nil, result.Error
}
return basis, nil
}
// Get a basis record using a reference protobuf message.
func (s *State) BasisGet(
ref *vagrant_plugin_sdk.Ref_Basis,
) (*vagrant_server.Basis, error) {
b, err := s.BasisFromProtoRef(ref)
if err != nil {
return nil, lookupErrorToStatus("basis", err)
}
return b.ToProto(), nil
}
// Find a basis record using a protobuf message
func (s *State) BasisFind(
b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) {
basis, err := s.BasisFromProtoFuzzy(b)
if err != nil {
return nil, lookupErrorToStatus("basis", err)
}
return basis.ToProto(), nil
}
// Store a basis record
func (s *State) BasisPut(
b *vagrant_server.Basis,
) (*vagrant_server.Basis, error) {
basis, err := s.BasisFromProto(b)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("basis", err)
}
// Make sure we don't have a nil
if err != nil {
basis = &Basis{}
}
err = s.softDecode(b, basis)
if err != nil {
return nil, saveErrorToStatus("basis", err)
}
if b.Configuration != nil {
if basis.Vagrantfile != nil {
basis.Vagrantfile.UpdateFromProto(b.Configuration)
} else {
basis.Vagrantfile = s.VagrantfileFromProto(b.Configuration)
}
}
result := s.db.Save(basis)
if result.Error != nil {
return nil, saveErrorToStatus("basis", result.Error)
}
return basis.ToProto(), nil
}
// List all basis records
func (s *State) BasisList() ([]*Basis, error) {
var all []*Basis
result := s.search().Find(&all)
if result.Error != nil {
return nil, lookupErrorToStatus("basis", result.Error)
}
return all, nil
}
// Delete a basis
func (s *State) BasisDelete(
b *vagrant_plugin_sdk.Ref_Basis,
) error {
basis, err := s.BasisFromProtoRef(b)
// If the record was not found, we return with no error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
// If an unexpected error was encountered, return it
if err != nil {
return lookupErrorToStatus("basis", err)
}
result := s.db.Delete(basis)
if result.Error != nil {
return deleteErrorToStatus("basis", result.Error)
}
return nil
}
var (
_ scope = (*Basis)(nil)
)

View File

@ -19,6 +19,42 @@ func TestBasis(t *testing.T) {
require.Error(err)
})
t.Run("Put creates and sets resource ID", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
testBasis := &vagrant_server.Basis{
Name: "test_name",
Path: "/User/test/test",
}
result, err := s.BasisPut(testBasis)
require.NoError(err)
require.NotEmpty(result.ResourceId)
})
t.Run("Put fails on duplicate name", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
testBasis := &vagrant_server.Basis{
Name: "test_name",
Path: "/User/test/test",
}
// Set initial record
_, err := s.BasisPut(testBasis)
require.NoError(err)
// Attempt to set it again
_, err = s.BasisPut(testBasis)
require.Error(err)
})
t.Run("Put and Get", func(t *testing.T) {
require := require.New(t)
@ -26,38 +62,23 @@ func TestBasis(t *testing.T) {
defer s.Close()
testBasis := &vagrant_server.Basis{
ResourceId: "test",
Name: "test_name",
Path: "/User/test/test",
}
testBasisRef := &vagrant_plugin_sdk.Ref_Basis{
ResourceId: "test",
Name: "test_name",
Path: "/User/test/test",
}
// Set
err := s.BasisPut(testBasis)
result, err := s.BasisPut(testBasis)
require.NoError(err)
testBasisRef := &vagrant_plugin_sdk.Ref_Basis{
ResourceId: result.ResourceId,
}
// Get full ref
{
resp, err := s.BasisGet(testBasisRef)
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBasis.Name)
}
// Get by id
{
resp, err := s.BasisGet(&vagrant_plugin_sdk.Ref_Basis{
ResourceId: "test",
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBasis.Name)
}
})
t.Run("Find", func(t *testing.T) {
@ -67,19 +88,18 @@ func TestBasis(t *testing.T) {
defer s.Close()
testBasis := &vagrant_server.Basis{
ResourceId: "test",
Name: "test_name",
Path: "/User/test/test",
}
// Set
err := s.BasisPut(testBasis)
result, err := s.BasisPut(testBasis)
require.NoError(err)
// Find by resource id
{
resp, err := s.BasisFind(&vagrant_server.Basis{
ResourceId: "test",
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
@ -114,7 +134,6 @@ func TestBasis(t *testing.T) {
defer s.Close()
testBasis := &vagrant_server.Basis{
ResourceId: "test",
Name: "test_name",
Path: "/User/test/test",
}
@ -126,8 +145,9 @@ func TestBasis(t *testing.T) {
require.NoError(err)
// Add basis
err = s.BasisPut(testBasis)
result, err := s.BasisPut(testBasis)
require.NoError(err)
testBasisRef.ResourceId = result.ResourceId
// No error when deleting basis
err = s.BasisDelete(testBasisRef)
@ -145,15 +165,13 @@ func TestBasis(t *testing.T) {
defer s.Close()
// Add basis'
err := s.BasisPut(&vagrant_server.Basis{
ResourceId: "test",
_, err := s.BasisPut(&vagrant_server.Basis{
Name: "test_name",
Path: "/User/test/test",
})
require.NoError(err)
err = s.BasisPut(&vagrant_server.Basis{
ResourceId: "test2",
_, err = s.BasisPut(&vagrant_server.Basis{
Name: "test_name2",
Path: "/User/test/test2",
})

View File

@ -1,305 +1,361 @@
package state
import (
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
"errors"
"fmt"
"time"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/go-ozzo/ozzo-validation/v4/is"
"github.com/hashicorp/go-version"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/mitchellh/mapstructure"
"gorm.io/gorm"
)
var boxBucket = []byte("box")
func init() {
dbBuckets = append(dbBuckets, boxBucket)
dbIndexers = append(dbIndexers, (*State).boxIndexInit)
schemas = append(schemas, boxIndexSchema)
models = append(models, &Box{})
}
const (
DEFAULT_BOX_VERSION = "0.0.0"
DEFAULT_BOX_CONSTRAINT = "> 0"
)
type Box struct {
gorm.Model
Directory *string `gorm:"not null"`
LastUpdate *time.Time `gorm:"autoUpdateTime"`
Metadata *ProtoValue
MetadataUrl *string
Name *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
Provider *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
Version *string `gorm:"uniqueIndex:idx_nameverprov;not null"`
}
func (b *Box) BeforeSave(tx *gorm.DB) error {
if b.ResourceId == nil {
if err := b.setId(); err != nil {
return err
}
}
// If version is not set, default it to 0
if b.Version == nil || *b.Version == "0" {
v := DEFAULT_BOX_VERSION
b.Version = &v
}
if err := b.Validate(tx); err != nil {
return err
}
return nil
}
func (b *Box) setId() error {
id, err := server.Id()
if err != nil {
return err
}
b.ResourceId = &id
return nil
}
func (b *Box) Validate(tx *gorm.DB) error {
err := validation.ValidateStruct(b,
validation.Field(&b.Directory, validation.Required),
validation.Field(&b.Name, validation.Required),
validation.Field(&b.Provider, validation.Required),
validation.Field(&b.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Box{}).
Where(&Box{ResourceId: b.ResourceId}).
Not(&Box{Model: gorm.Model{ID: b.ID}}),
),
),
),
validation.Field(&b.Version,
validation.Required,
is.Semver,
),
)
if err != nil {
return err
}
err = validation.Validate(b,
validation.By(
checkUnique(
tx.Model(&Box{}).
Where(&Box{Name: b.Name, Provider: b.Provider, Version: b.Version}).
Not(&Box{Model: gorm.Model{ID: b.ID}}),
),
),
)
if err != nil {
return fmt.Errorf("name, provider and version %s", err)
}
return nil
}
func (b *Box) ToProto() *vagrant_server.Box {
var p vagrant_server.Box
err := decode(b, &p)
if err != nil {
panic(fmt.Sprintf("failed to decode box: " + err.Error()))
}
return &p
}
func (b *Box) ToProtoRef() *vagrant_plugin_sdk.Ref_Box {
var p vagrant_plugin_sdk.Ref_Box
err := decode(b, &p)
if err != nil {
panic(fmt.Sprintf("failed to decode box ref: " + err.Error()))
}
return &p
}
func (s *State) BoxFromProtoRef(
b *vagrant_plugin_sdk.Ref_Box,
) (*Box, error) {
if b == nil {
return nil, ErrEmptyProtoArgument
}
if b.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var box Box
result := s.search().First(&box, &Box{ResourceId: &b.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &box, nil
}
func (s *State) BoxFromProtoRefFuzzy(
b *vagrant_plugin_sdk.Ref_Box,
) (*Box, error) {
box, err := s.BoxFromProtoRef(b)
if err == nil {
return box, nil
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if b.Name == "" || b.Provider == "" || b.Version == "" {
return nil, gorm.ErrRecordNotFound
}
box = &Box{}
result := s.search().First(box,
&Box{
Name: &b.Name,
Provider: &b.Provider,
Version: &b.Version,
},
)
if result.Error != nil {
return nil, result.Error
}
return box, nil
}
func (s *State) BoxFromProto(
b *vagrant_server.Box,
) (*Box, error) {
return s.BoxFromProtoRef(
&vagrant_plugin_sdk.Ref_Box{
ResourceId: b.ResourceId,
},
)
}
func (s *State) BoxFromProtoFuzzy(
b *vagrant_server.Box,
) (*Box, error) {
return s.BoxFromProtoRefFuzzy(
&vagrant_plugin_sdk.Ref_Box{
Name: b.Name,
Provider: b.Provider,
ResourceId: b.ResourceId,
Version: b.Version,
},
)
}
func (s *State) BoxList() ([]*vagrant_plugin_sdk.Ref_Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
return s.boxList(memTxn)
}
func (s *State) BoxDelete(ref *vagrant_plugin_sdk.Ref_Box) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.boxDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
var boxes []Box
result := s.db.Find(&boxes)
if result.Error != nil {
return nil, lookupErrorToStatus("boxes", result.Error)
}
return err
}
func (s *State) BoxGet(ref *vagrant_plugin_sdk.Ref_Box) (*vagrant_server.Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Box
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.boxGet(dbTxn, memTxn, ref)
return err
})
return result, err
}
func (s *State) BoxPut(box *vagrant_server.Box) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.boxPut(dbTxn, memTxn, box)
})
if err == nil {
memTxn.Commit()
refs := make([]*vagrant_plugin_sdk.Ref_Box, len(boxes))
for i, b := range boxes {
refs[i] = b.ToProtoRef()
}
return err
return refs, nil
}
func (s *State) BoxFind(b *vagrant_plugin_sdk.Ref_Box) (*vagrant_server.Box, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Box
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.boxFind(dbTxn, memTxn, b)
return err
})
return result, err
}
func (s *State) boxList(
memTxn *memdb.Txn,
) (r []*vagrant_plugin_sdk.Ref_Box, err error) {
iter, err := memTxn.Get(boxIndexTableName, boxIndexIdIndexName+"_prefix", "")
if err != nil {
return nil, err
}
var result []*vagrant_plugin_sdk.Ref_Box
for {
next := iter.Next()
if next == nil {
break
}
result = append(result, &vagrant_plugin_sdk.Ref_Box{
ResourceId: next.(*boxIndexRecord).Id,
Name: next.(*boxIndexRecord).Name,
Version: next.(*boxIndexRecord).Version,
Provider: next.(*boxIndexRecord).Provider,
})
}
return result, nil
}
func (s *State) boxDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Box,
) (err error) {
b, err := s.boxGet(dbTxn, memTxn, ref)
if err != nil {
if status.Code(err) == codes.NotFound {
func (s *State) BoxDelete(
b *vagrant_plugin_sdk.Ref_Box,
) error {
box, err := s.BoxFromProtoRef(b)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return err
if err != nil {
return deleteErrorToStatus("box", err)
}
// Delete the box
if err = dbTxn.Bucket(boxBucket).Delete(s.boxId(b)); err != nil {
return
result := s.db.Delete(box)
if result.Error != nil {
return deleteErrorToStatus("box", result.Error)
}
if err = memTxn.Delete(boxIndexTableName, s.newBoxIndexRecord(b)); err != nil {
return
}
return
return nil
}
func (s *State) boxGet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
func (s *State) BoxGet(
b *vagrant_plugin_sdk.Ref_Box,
) (*vagrant_server.Box, error) {
box, err := s.BoxFromProtoRef(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
func (s *State) BoxPut(b *vagrant_server.Box) error {
box, err := s.BoxFromProto(b)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return lookupErrorToStatus("box", err)
}
if err != nil {
box = &Box{}
}
err = s.softDecode(b, box)
if err != nil {
return saveErrorToStatus("box", err)
}
result := s.db.Save(box)
if result.Error != nil {
return saveErrorToStatus("box", result.Error)
}
return nil
}
func (s *State) BoxFind(
ref *vagrant_plugin_sdk.Ref_Box,
) (r *vagrant_server.Box, err error) {
var result vagrant_server.Box
b := dbTxn.Bucket(boxBucket)
return &result, dbGet(b, s.boxIdByRef(ref), &result)
}
func (s *State) boxPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Box,
) (err error) {
id := s.boxId(value)
b := dbTxn.Bucket(boxBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store box in db", "box", value, "error", err)
return
) (*vagrant_server.Box, error) {
b := &vagrant_plugin_sdk.Ref_Box{}
if err := mapstructure.Decode(ref, b); err != nil {
return nil, lookupErrorToStatus("box", err)
}
s.log.Trace("indexing box", "box", value)
if err = s.boxIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index box", "box", value, "error", err)
return
if b.ResourceId != "" {
box, err := s.BoxFromProtoRef(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
return
}
// If no name is given, we error immediately
if b.Name == "" {
return nil, lookupErrorToStatus("box", fmt.Errorf("no name given for box lookup"))
}
// If no provider is given, we error immediately
if b.Provider == "" {
return nil, lookupErrorToStatus("box", fmt.Errorf("no provider given for box lookup"))
}
func (s *State) boxFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Box,
) (r *vagrant_server.Box, err error) {
var match *boxIndexRecord
highestVersion, _ := version.NewVersion("0.0.0")
req := s.newBoxIndexRecordByRef(ref)
// Get the name first
if req.Name != "" {
raw, err := memTxn.Get(
boxIndexTableName,
boxIndexNameIndexName,
req.Name,
// If the version is set to 0, mark it as default
if b.Version == "0" {
b.Version = DEFAULT_BOX_VERSION
}
// If we are provided an explicit version, just do a direct lookup
if _, err := version.NewVersion(b.Version); err == nil {
box, err := s.BoxFromProtoRefFuzzy(b)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
return box.ToProto(), nil
}
var boxes []Box
result := s.search().Find(&boxes,
&Box{
Name: &b.Name,
Provider: &b.Provider,
},
)
if err != nil {
return nil, err
}
if req.Version == "" {
req.Version = ">= 0"
}
versionConstraint, err := version.NewConstraint(req.Version)
if err != nil {
return nil, err
if result.Error != nil {
return nil, lookupErrorToStatus("box", result.Error)
}
for e := raw.Next(); e != nil; e = raw.Next() {
boxIndexEntry := e.(*boxIndexRecord)
if req.Version != "" {
boxVersion, _ := version.NewVersion(boxIndexEntry.Version)
// If we found no boxes, return a not found error
if len(boxes) < 1 {
return nil, lookupErrorToStatus("box", gorm.ErrRecordNotFound)
}
// If we have no version value set, apply the default
// version constraint
if b.Version == "" {
b.Version = DEFAULT_BOX_CONSTRAINT
}
var match *Box
highestVersion, _ := version.NewVersion("0.0.0")
versionConstraint, err := version.NewConstraint(b.Version)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
for _, box := range boxes {
boxVersion, err := version.NewVersion(*box.Version)
if err != nil {
return nil, lookupErrorToStatus("box", err)
}
if !versionConstraint.Check(boxVersion) {
continue
}
}
if req.Provider != "" {
if boxIndexEntry.Provider != req.Provider {
continue
}
}
// Set first match
if match == nil {
match = boxIndexEntry
}
v, _ := version.NewVersion(boxIndexEntry.Version)
if v.GreaterThan(highestVersion) {
highestVersion = v
match = boxIndexEntry
if boxVersion.GreaterThan(highestVersion) {
match = &box
highestVersion = boxVersion
}
}
if match != nil {
return s.boxGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Box{
ResourceId: match.Id,
})
}
return match.ToProto(), nil
}
return
}
const (
boxIndexIdIndexName = "id"
boxIndexNameIndexName = "name"
boxIndexTableName = "box-index"
)
type boxIndexRecord struct {
Id string // Resource ID
Name string // Box Name
Version string // Box Version
Provider string // Box Provider
}
func (s *State) newBoxIndexRecord(b *vagrant_server.Box) *boxIndexRecord {
id := b.Name + "-" + b.Version + "-" + b.Provider
return &boxIndexRecord{
Id: id,
Name: b.Name,
Version: b.Version,
Provider: b.Provider,
}
}
func (s *State) boxIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Box) error {
return txn.Insert(boxIndexTableName, s.newBoxIndexRecord(value))
}
func (s *State) boxIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(boxBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Box
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
if err := s.boxIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil
})
}
func boxIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: boxIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
boxIndexIdIndexName: {
Name: boxIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: true,
},
},
boxIndexNameIndexName: {
Name: boxIndexNameIndexName,
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
},
}
}
func (s *State) newBoxIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Box) *boxIndexRecord {
return &boxIndexRecord{
Id: ref.ResourceId,
Name: ref.Name,
Version: ref.Version,
Provider: ref.Provider,
}
}
func (s *State) boxId(b *vagrant_server.Box) []byte {
return []byte(b.Id)
}
func (s *State) boxIdByRef(b *vagrant_plugin_sdk.Ref_Box) []byte {
return []byte(b.ResourceId)
return nil, lookupErrorToStatus("box", gorm.ErrRecordNotFound)
}

View File

@ -26,12 +26,21 @@ func TestBox(t *testing.T) {
defer s.Close()
testBox := &vagrant_server.Box{
Id: "qwerwasdf",
ResourceId: "qwerwasdf",
Directory: "/directory",
Name: "hashicorp/bionic",
Version: "1.2.3",
Provider: "virtualbox",
}
testBox2 := &vagrant_server.Box{
ResourceId: "qwerwasdf-2",
Directory: "/directory-2",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
}
testBoxRef := &vagrant_plugin_sdk.Ref_Box{
ResourceId: "qwerwasdf",
Name: "hashicorp/bionic",
@ -39,9 +48,18 @@ func TestBox(t *testing.T) {
Provider: "virtualbox",
}
testBoxRef2 := &vagrant_plugin_sdk.Ref_Box{
ResourceId: "qwerwasdf-2",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
}
// Set
err := s.BoxPut(testBox)
require.NoError(err)
err = s.BoxPut(testBox2)
require.NoError(err)
// Get full ref
{
@ -49,6 +67,11 @@ func TestBox(t *testing.T) {
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBox.Name)
resp, err = s.BoxGet(testBoxRef2)
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.Name, testBox2.Name)
}
// Get by id
@ -69,7 +92,8 @@ func TestBox(t *testing.T) {
defer s.Close()
testBox := &vagrant_server.Box{
Id: "qwerwasdf",
ResourceId: "qwerwasdf",
Directory: "/directory",
Name: "hashicorp/bionic",
Version: "1.2.3",
Provider: "virtualbox",
@ -98,7 +122,8 @@ func TestBox(t *testing.T) {
defer s.Close()
err := s.BoxPut(&vagrant_server.Box{
Id: "qwerwasdf",
ResourceId: "qwerwasdf",
Directory: "/directory",
Name: "hashicorp/bionic",
Version: "1.2.3",
Provider: "virtualbox",
@ -106,7 +131,8 @@ func TestBox(t *testing.T) {
require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{
Id: "rrbrwasdf",
ResourceId: "rrbrwasdf",
Directory: "/other-directory",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
@ -125,7 +151,8 @@ func TestBox(t *testing.T) {
defer s.Close()
err := s.BoxPut(&vagrant_server.Box{
Id: "hashicorp/bionic-1.2.3-virtualbox",
ResourceId: "hashicorp/bionic-1.2.3-virtualbox",
Directory: "/directory",
Name: "hashicorp/bionic",
Version: "1.2.3",
Provider: "virtualbox",
@ -133,7 +160,8 @@ func TestBox(t *testing.T) {
require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{
Id: "hashicorp/bionic-1.2.4-virtualbox",
ResourceId: "hashicorp/bionic-1.2.4-virtualbox",
Directory: "/other-directory",
Name: "hashicorp/bionic",
Version: "1.2.4",
Provider: "virtualbox",
@ -141,7 +169,8 @@ func TestBox(t *testing.T) {
require.NoError(err)
err = s.BoxPut(&vagrant_server.Box{
Id: "box-0-virtualbox",
ResourceId: "box-0-virtualbox",
Directory: "/another-directory",
Name: "box",
Version: "0",
Provider: "virtualbox",
@ -150,19 +179,12 @@ func TestBox(t *testing.T) {
b, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Provider: "virtualbox",
})
require.NoError(err)
require.Equal(b.Name, "hashicorp/bionic")
require.Equal(b.Version, "1.2.4")
b2, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Version: "1.2.3",
})
require.NoError(err)
require.Equal(b2.Name, "hashicorp/bionic")
require.Equal(b2.Version, "1.2.3")
b3, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Version: "1.2.3",
@ -178,7 +200,7 @@ func TestBox(t *testing.T) {
Version: "1.2.3",
Provider: "dontexist",
})
require.NoError(err)
require.Error(err)
require.Nil(b4)
b5, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
@ -186,23 +208,24 @@ func TestBox(t *testing.T) {
Version: "9.9.9",
Provider: "virtualbox",
})
require.NoError(err)
require.Error(err)
require.Nil(b5)
b6, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Version: "1.2.3",
})
require.NoError(err)
require.Error(err)
require.Nil(b6)
b7, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "dontexist",
})
require.NoError(err)
require.Error(err)
require.Nil(b7)
b8, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Provider: "virtualbox",
Version: "~> 1.2",
})
require.NoError(err)
@ -211,6 +234,7 @@ func TestBox(t *testing.T) {
b9, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "hashicorp/bionic",
Provider: "virtualbox",
Version: "> 1.0, < 3.0",
})
require.NoError(err)
@ -221,15 +245,16 @@ func TestBox(t *testing.T) {
Name: "hashicorp/bionic",
Version: "< 1.0",
})
require.NoError(err)
require.Error(err)
require.Nil(b10)
b11, err := s.BoxFind(&vagrant_plugin_sdk.Ref_Box{
Name: "box",
Version: "0",
Provider: "virtualbox",
})
require.NoError(err)
require.Equal(b11.Name, "box")
require.Equal(b11.Version, "0")
require.Equal(b11.Version, "0.0.0")
})
}

View File

@ -0,0 +1,70 @@
package state
import (
"github.com/hashicorp/vagrant-plugin-sdk/component"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
type Component struct {
gorm.Model
Name string `gorm:"uniqueIndex:idx_stname"`
ServerAddr string `gorm:"uniqueIndex:idx_stname"`
Type component.Type `gorm:"uniqueIndex:idx_stname"`
Runners []*Runner `gorm:"many2many:runner_components"`
}
func init() {
models = append(models, &Component{})
}
func (c *Component) ToProtoRef() *vagrant_server.Ref_Component {
if c == nil {
return nil
}
return &vagrant_server.Ref_Component{
Type: vagrant_server.Component_Type(c.Type),
Name: c.Name,
}
}
func (c *Component) ToProto() *vagrant_server.Component {
if c == nil {
return nil
}
return &vagrant_server.Component{
Type: vagrant_server.Component_Type(c.Type),
Name: c.Name,
ServerAddr: c.ServerAddr,
}
}
func (s *State) ComponentFromProto(p *vagrant_server.Component) (*Component, error) {
var c Component
result := s.db.First(&c, &Component{
Name: p.Name,
ServerAddr: p.ServerAddr,
Type: component.Type(p.Type),
})
if result.Error == nil {
return &c, nil
}
if result.Error == gorm.ErrRecordNotFound {
c.Name = p.Name
c.ServerAddr = p.ServerAddr
c.Type = component.Type(p.Type)
result = s.db.Save(&c)
if result.Error != nil {
return nil, result.Error
}
return &c, nil
}
return nil, result.Error
}

View File

@ -1,40 +1,72 @@
package state
// TODO(spox): When dealing with the scopes on the configvar protos,
// we need to do lookups + fillins to populate parents so we index
// them correctly in memory and can properly do lookups
import (
"errors"
"fmt"
"sort"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
bolt "go.etcd.io/bbolt"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serversort "github.com/hashicorp/vagrant/internal/server/sort"
"gorm.io/gorm"
)
var configBucket = []byte("config")
type Config struct {
gorm.Model
Cid *string `gorm:"uniqueIndex"`
Name string
Scope *ProtoValue // TODO(spox): polymorphic needs to allow for runner
Value string
}
func init() {
dbBuckets = append(dbBuckets, configBucket)
models = append(models, &Config{})
dbIndexers = append(dbIndexers, (*State).configIndexInit)
schemas = append(schemas, configIndexSchema)
}
func (c *Config) ToProto() *vagrant_server.ConfigVar {
if c == nil {
return nil
}
var config vagrant_server.ConfigVar
if err := decode(c, &config); err != nil {
panic("failed to decode config: " + err.Error())
}
return &config
}
func (s *State) ConfigFromProto(p *vagrant_server.ConfigVar) (*Config, error) {
var c Config
cid := string(s.configVarId(p))
result := s.db.First(&c, &Config{Cid: &cid})
if result.Error != nil {
return nil, result.Error
}
return &c, nil
}
// ConfigSet writes a configuration variable to the data store.
func (s *State) ConfigSet(vs ...*vagrant_server.ConfigVar) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
var err error
err := s.db.Update(func(dbTxn *bolt.Tx) error {
for _, v := range vs {
if err := s.configSet(dbTxn, memTxn, v); err != nil {
if err := s.configSet(memTxn, v); err != nil {
return err
}
}
return nil
})
if err == nil {
memTxn.Commit()
}
@ -53,63 +85,91 @@ func (s *State) ConfigGetWatch(req *vagrant_server.ConfigGetRequest, ws memdb.Wa
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result []*vagrant_server.ConfigVar
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.configGetMerged(dbTxn, memTxn, ws, req)
return err
})
return result, err
return s.configGetMerged(memTxn, ws, req)
}
func (s *State) configSet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.ConfigVar,
) error {
id := s.configVarId(value)
// Get the global bucket and write the value to it.
b := dbTxn.Bucket(configBucket)
if value.Value == "" {
if err := b.Delete(id); err != nil {
// Persist the configuration in the db
c, err := s.ConfigFromProto(value)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
} else {
if err := dbPut(b, id, value); err != nil {
return err
if err != nil {
cid := string(id)
c = &Config{Cid: &cid}
}
if err = s.softDecode(value, c); err != nil {
return saveErrorToStatus("config", err)
}
result := s.db.Save(c)
if result.Error != nil {
return saveErrorToStatus("config", result.Error)
}
// Create our index value and write that.
return s.configIndexSet(memTxn, id, value)
if err = s.configIndexSet(memTxn, id, value); err != nil {
return saveErrorToStatus("config", err)
}
return nil
}
func (s *State) configGetMerged(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ws memdb.WatchSet,
req *vagrant_server.ConfigGetRequest,
) ([]*vagrant_server.ConfigVar, error) {
var mergeSet [][]*vagrant_server.ConfigVar
switch scope := req.Scope.(type) {
case *vagrant_server.ConfigGetRequest_Basis:
// For basis scope, we just return the basis scoped values
return s.configGetExact(memTxn, ws, scope.Basis, req.Prefix)
case *vagrant_server.ConfigGetRequest_Project:
// For project scope, we just return the project scoped values.
return s.configGetExact(dbTxn, memTxn, ws, scope.Project, req.Prefix)
// TODO(spox): this should be a "something" (do we allow config for any machine,project,basis?)
// case *vagrant_server.ConfigGetRequest_Application:
// For project scope, we collect project and basis values
m, err := s.configGetExact(memTxn, ws, scope.Project.Basis, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Project, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
case *vagrant_server.ConfigGetRequest_Target:
// For project scope, we collect project and basis values
m, err := s.configGetExact(memTxn, ws, scope.Target.Project.Basis, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Target.Project, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
m, err = s.configGetExact(memTxn, ws, scope.Target, req.Prefix)
if err != nil {
return nil, err
}
mergeSet = append(mergeSet, m)
case *vagrant_server.ConfigGetRequest_Runner:
var err error
mergeSet, err = s.configGetRunner(dbTxn, memTxn, ws, scope.Runner, req.Prefix)
mergeSet, err = s.configGetRunner(memTxn, ws, scope.Runner, req.Prefix)
if err != nil {
return nil, err
}
default:
panic("unknown scope")
return nil, fmt.Errorf("unknown scope type provided (%T)", req.Scope)
}
// Merge our merge set
@ -132,34 +192,54 @@ func (s *State) configGetMerged(
// configGetExact returns the list of config variables for a scope
// exactly. By "exactly" we mean without any merging logic: if you request
// app-scoped variables, you'll get app-scoped variables. If a project-scoped
// target-scoped variables, you'll get target-scoped variables. If a project-scoped
// variable matches, it will not be merged in.
func (s *State) configGetExact(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ws memdb.WatchSet,
ref interface{}, // should be one of the *vagrant_server.Ref_ values.
ref interface{}, // should be one of the *vagrant_plugin_sdk.Ref_ or *vagrant_server.Ref_ values.
prefix string,
) ([]*vagrant_server.ConfigVar, error) {
// We have to get the correct iterator based on the scope. We check the
// scope and use the proper index to get the iterator here.
var iter memdb.ResultIterator
switch ref := ref.(type) {
case *vagrant_plugin_sdk.Ref_Project:
var err error
switch v := ref.(type) {
case *vagrant_plugin_sdk.Ref_Basis:
iter, err = memTxn.Get(
configIndexTableName,
configIndexProjectIndexName+"_prefix",
ref.ResourceId,
configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
fmt.Sprintf("%s/%s", v.ResourceId, prefix),
)
if err != nil {
return nil, err
}
case *vagrant_plugin_sdk.Ref_Project:
iter, err = memTxn.Get(
configIndexTableName,
configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
fmt.Sprintf("%s/%s/%s", v.Basis.ResourceId, v.ResourceId, prefix),
)
if err != nil {
return nil, err
}
case *vagrant_plugin_sdk.Ref_Target:
iter, err = memTxn.Get(
configIndexTableName,
configIndexIdIndexName+"_prefix", // Enable a prefix match on lookup
fmt.Sprintf("%s/%s/%s/%s",
v.Project.Basis.ResourceId,
v.Project.ResourceId,
v.ResourceId,
prefix,
),
)
if err != nil {
return nil, err
}
default:
panic("unknown scope")
return nil, fmt.Errorf("unknown scope type provided (%T)", ref)
}
// Add to our watchset
@ -167,20 +247,20 @@ func (s *State) configGetExact(
// Go through the iterator and accumulate the results
var result []*vagrant_server.ConfigVar
b := dbTxn.Bucket(configBucket)
for {
current := iter.Next()
if current == nil {
break
}
var value vagrant_server.ConfigVar
var value Config
record := current.(*configIndexRecord)
if err := dbGet(b, []byte(record.Id), &value); err != nil {
return nil, err
res := s.db.First(&value, &Config{Cid: &record.Id})
if res.Error != nil {
return nil, res.Error
}
result = append(result, &value)
result = append(result, value.ToProto())
}
return result, nil
@ -188,7 +268,6 @@ func (s *State) configGetExact(
// configGetRunner gets the config vars for a runner.
func (s *State) configGetRunner(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ws memdb.WatchSet,
req *vagrant_server.Ref_RunnerId,
@ -196,7 +275,7 @@ func (s *State) configGetRunner(
) ([][]*vagrant_server.ConfigVar, error) {
iter, err := memTxn.Get(
configIndexTableName,
configIndexRunnerIndexName+"_prefix",
configIndexRunnerIndexName+"_prefix", // Enable a prefix match on lookup
true,
prefix,
)
@ -214,8 +293,6 @@ func (s *State) configGetRunner(
idxId = 1
)
// Go through the iterator and accumulate the results
b := dbTxn.Bucket(configBucket)
for {
current := iter.Next()
if current == nil {
@ -240,12 +317,13 @@ func (s *State) configGetRunner(
return nil, fmt.Errorf("config has unknown target type: %T", record.RunnerRef.Target)
}
var value vagrant_server.ConfigVar
if err := dbGet(b, []byte(record.Id), &value); err != nil {
return nil, err
var value Config
res := s.db.First(&value, &Config{Cid: &record.Id})
if res.Error != nil {
return nil, res.Error
}
result[idx] = append(result[idx], &value)
result[idx] = append(result[idx], value.ToProto())
}
return result, nil
@ -253,26 +331,26 @@ func (s *State) configGetRunner(
// configIndexSet writes an index record for a single config var.
func (s *State) configIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.ConfigVar) error {
var project, application string
var basis, project, target string
var runner *vagrant_server.Ref_Runner
switch scope := value.Scope.(type) {
//TODO(spox): Does this need to be machine? Need basis too?
//case *vagrant_server.ConfigVar_Application:
case *vagrant_server.ConfigVar_Basis:
basis = scope.Basis.ResourceId
case *vagrant_server.ConfigVar_Project:
project = scope.Project.ResourceId
case *vagrant_server.ConfigVar_Target:
target = scope.Target.ResourceId
case *vagrant_server.ConfigVar_Runner:
runner = scope.Runner
default:
panic("unknown scope")
}
record := &configIndexRecord{
Id: string(id),
Basis: basis,
Project: project,
Application: application,
Target: target,
Name: value.Name,
Runner: runner != nil,
RunnerRef: runner,
@ -288,33 +366,48 @@ func (s *State) configIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.
}
// configIndexInit initializes the config index from persisted data.
func (s *State) configIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(configBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.ConfigVar
if err := proto.Unmarshal(v, &value); err != nil {
func (s *State) configIndexInit(memTxn *memdb.Txn) error {
var cfgs []Config
result := s.db.Find(&cfgs)
if result.Error != nil {
return result.Error
}
for _, c := range cfgs {
p := c.ToProto()
if err := s.configIndexSet(memTxn, s.configVarId(p), p); err != nil {
return err
}
if err := s.configIndexSet(memTxn, k, &value); err != nil {
return err
}
return nil
})
}
func (s *State) configVarId(v *vagrant_server.ConfigVar) []byte {
switch scope := v.Scope.(type) {
// TODO(spox): same as above with machine/basis/etc
//case *vagrant_server.ConfigVar_Application:
case *vagrant_server.ConfigVar_Project:
return []byte(fmt.Sprintf("%s/%s/%s",
scope.Project.ResourceId,
"",
case *vagrant_server.ConfigVar_Basis:
return []byte(
fmt.Sprintf("%v/%v",
scope.Basis.Name,
v.Name,
))
),
)
case *vagrant_server.ConfigVar_Project:
return []byte(
fmt.Sprintf("%v/%v/%v",
scope.Project.Basis.ResourceId,
scope.Project.ResourceId,
v.Name,
),
)
case *vagrant_server.ConfigVar_Target:
return []byte(
fmt.Sprintf("%v/%v/%v/%v",
scope.Target.Project.Basis.ResourceId,
scope.Target.Project.ResourceId,
scope.Target.ResourceId,
v.Name,
),
)
case *vagrant_server.ConfigVar_Runner:
var t string
switch scope.Runner.Target.(type) {
@ -345,8 +438,27 @@ func configIndexSchema() *memdb.TableSchema {
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
configIndexBasisIndexName: {
Name: configIndexBasisIndexName,
AllowMissing: true,
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
},
},
configIndexProjectIndexName: {
@ -355,6 +467,11 @@ func configIndexSchema() *memdb.TableSchema {
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "Project",
Lowercase: true,
@ -368,19 +485,24 @@ func configIndexSchema() *memdb.TableSchema {
},
},
configIndexApplicationIndexName: {
Name: configIndexApplicationIndexName,
configIndexTargetIndexName: {
Name: configIndexTargetIndexName,
AllowMissing: true,
Unique: false,
Indexer: &memdb.CompoundIndex{
Indexes: []memdb.Indexer{
&memdb.StringFieldIndex{
Field: "Basis",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "Project",
Lowercase: true,
},
&memdb.StringFieldIndex{
Field: "Application",
Field: "Target",
Lowercase: true,
},
@ -416,15 +538,17 @@ func configIndexSchema() *memdb.TableSchema {
const (
configIndexTableName = "config-index"
configIndexIdIndexName = "id"
configIndexBasisIndexName = "basis"
configIndexProjectIndexName = "project"
configIndexApplicationIndexName = "application"
configIndexTargetIndexName = "target"
configIndexRunnerIndexName = "runner"
)
type configIndexRecord struct {
Id string
Basis string
Project string
Application string
Target string
Name string
Runner bool // true if this is a runner config
RunnerRef *vagrant_server.Ref_Runner

View File

@ -5,10 +5,8 @@ import (
"time"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/stretchr/testify/require"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/stretchr/testify/require"
)
func TestConfig(t *testing.T) {
@ -17,13 +15,12 @@ func TestConfig(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create a build
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "foo",
@ -34,7 +31,7 @@ func TestConfig(t *testing.T) {
// Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "foo",
@ -47,7 +44,7 @@ func TestConfig(t *testing.T) {
// Get it via a prefix match
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "",
@ -60,7 +57,7 @@ func TestConfig(t *testing.T) {
// non-matching prefix
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "bar",
@ -76,13 +73,13 @@ func TestConfig(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create a build
require.NoError(s.ConfigSet(
&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "global",
@ -90,9 +87,7 @@ func TestConfig(t *testing.T) {
},
&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "hello",
@ -104,9 +99,7 @@ func TestConfig(t *testing.T) {
// Get our merged variables
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
})
require.NoError(err)
@ -122,9 +115,7 @@ func TestConfig(t *testing.T) {
// Get project scoped variables. This should return everything.
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
})
require.NoError(err)
@ -138,12 +129,12 @@ func TestConfig(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create a var
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "foo",
@ -154,7 +145,7 @@ func TestConfig(t *testing.T) {
// Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "foo",
@ -166,9 +157,7 @@ func TestConfig(t *testing.T) {
// Delete it
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "foo",
@ -179,7 +168,7 @@ func TestConfig(t *testing.T) {
// Get it exactly
vs, err := s.ConfigGet(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "foo",
@ -195,6 +184,8 @@ func TestConfig(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create the config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Runner{
@ -212,9 +203,7 @@ func TestConfig(t *testing.T) {
// Create a var that shouldn't match
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "bar",
@ -267,6 +256,8 @@ func TestConfig(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create the config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Runner{
@ -286,9 +277,7 @@ func TestConfig(t *testing.T) {
// Create a var that shouldn't match
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "bar",
@ -380,12 +369,14 @@ func TestConfigWatch(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
ws := memdb.NewWatchSet()
// Get it with watch
vs, err := s.ConfigGetWatch(&vagrant_server.ConfigGetRequest{
Scope: &vagrant_server.ConfigGetRequest_Project{
Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "foo"},
Project: projRef,
},
Prefix: "foo",
@ -399,9 +390,7 @@ func TestConfigWatch(t *testing.T) {
// Create a config
require.NoError(s.ConfigSet(&vagrant_server.ConfigVar{
Scope: &vagrant_server.ConfigVar_Project{
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "foo",
},
Project: projRef,
},
Name: "foo",

View File

@ -0,0 +1,808 @@
package state
import (
"fmt"
"reflect"
"time"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type Decoder struct {
*mapstructure.Decoder
}
func NewDecoder(config *mapstructure.DecoderConfig) (*Decoder, error) {
intD, err := mapstructure.NewDecoder(config)
if err != nil {
return nil, err
}
return &Decoder{intD}, nil
}
func (d *Decoder) SoftDecode(input interface{}) error {
v := reflect.Indirect(reflect.ValueOf(input))
t := v.Type()
if v.Kind() != reflect.Struct {
return d.Decode(input)
}
newFields := []reflect.StructField{}
for i := 0; i < t.NumField(); i++ {
structField := t.Field(i)
if !structField.IsExported() {
continue
}
indirectVal := reflect.Indirect(v.FieldByName(structField.Name))
if !indirectVal.IsValid() {
continue
}
val := indirectVal.Interface()
fieldType := structField.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
defaultZero := reflect.Zero(fieldType).Interface()
if reflect.DeepEqual(val, defaultZero) {
continue
}
newField := reflect.StructField{
Name: structField.Name,
PkgPath: structField.PkgPath,
Tag: structField.Tag,
Type: structField.Type,
}
newFields = append(newFields, newField)
}
newStruct := reflect.StructOf(newFields)
newInput := reflect.New(newStruct).Elem()
err := mapstructure.Decode(input, newInput.Addr().Interface())
if err != nil {
// This should not happen, but if it does, we need to bail
panic("failed to generate decode copy: " + err.Error())
}
// pval := v.FieldByName("Project").Interface()
// nval := newInput.FieldByName("Project").Interface()
// e := d.Decode(newInput.Interface())
// if e != nil {
// panic(e)
// }
// nval2 := newInput.FieldByName("Project").Interface()
// pval2 := v.FieldByName("Project").Interface()
// panic(fmt.Sprintf("\n\nold value: %#v\nnew value: %#v\n\n --- \nold value: %#v\nnew value: %#v\n",
// pval, nval, pval2, nval2))
return d.Decode(newInput.Interface())
}
// Creates a decoder with all our custom hooks. This decoder can
// be used for converting models to protobuf messages and converting
// protobuf messages to models.
func (s *State) decoder(output interface{}) *Decoder {
config := mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
projectToProtoRefHookFunc,
projectToProtoHookFunc,
s.projectFromProtoHookFunc,
s.projectFromProtoRefHookFunc,
basisToProtoHookFunc,
basisToProtoRefHookFunc,
s.basisFromProtoHookFunc,
s.basisFromProtoRefHookFunc,
targetToProtoHookFunc,
targetToProtoRefHookFunc,
s.targetFromProtoHookFunc,
s.targetFromProtoRefHookFunc,
vagrantfileToProtoHookFunc,
s.vagrantfileFromProtoHookFunc,
runnerToProtoHookFunc,
s.runnerFromProtoHookFunc,
protobufToProtoValueHookFunc,
protobufToProtoRawHookFunc,
boxToProtoHookFunc,
boxToProtoRefHookFunc,
s.boxFromProtoHookFunc,
s.boxFromProtoRefHookFunc,
timeToProtoHookFunc,
timeFromProtoHookFunc,
s.scopeFromProtoHookFunc,
scopeToProtoHookFunc,
protoValueToProtoHookFunc,
protoRawToProtoHookFunc,
s.componentFromProtoHookFunc,
componentToProtoHookFunc,
),
Result: output,
}
d, err := NewDecoder(&config)
if err != nil {
panic("failed to create mapstructure decoder: " + err.Error())
}
return d
}
// Decodes input into output structure using custom decoder
func (s *State) decode(input, output interface{}) error {
return s.decoder(output).Decode(input)
}
func (s *State) softDecode(input, output interface{}) error {
return s.decoder(output).SoftDecode(input)
}
// Creates a decoder with some of our custom hooks. This can be used
// for converting models to protobuf messages but cannot be used for
// converting protobuf messages to models.
func decoder(output interface{}) *Decoder {
config := mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
projectToProtoRefHookFunc,
projectToProtoHookFunc,
basisToProtoHookFunc,
basisToProtoRefHookFunc,
targetToProtoHookFunc,
targetToProtoRefHookFunc,
vagrantfileToProtoHookFunc,
runnerToProtoHookFunc,
protobufToProtoValueHookFunc,
protobufToProtoRawHookFunc,
boxToProtoHookFunc,
boxToProtoRefHookFunc,
timeToProtoHookFunc,
timeFromProtoHookFunc,
scopeToProtoHookFunc,
protoValueToProtoHookFunc,
protoRawToProtoHookFunc,
componentToProtoHookFunc,
),
Result: output,
}
d, err := NewDecoder(&config)
if err != nil {
panic("failed to create mapstructure decoder: " + err.Error())
}
return d
}
// Decodes input into output structure using custom decoder
func decode(input, output interface{}) error {
return decoder(output).Decode(input)
}
func softDecode(input, output interface{}) error {
return decoder(output).SoftDecode(input)
}
// Everything below here are converters
func projectToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Project)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Project)(nil)) {
return data, nil
}
p, ok := data.(*Project)
if !ok {
return nil, fmt.Errorf("cannot serialize project ref, wrong type (%T)", data)
}
return p.ToProtoRef(), nil
}
func projectToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Project)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Project)(nil)) {
return data, nil
}
p, ok := data.(*Project)
if !ok {
return nil, fmt.Errorf("cannot serialize project, wrong type (%T)", data)
}
return p.ToProto(), nil
}
func (s *State) projectFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Project)(nil)) ||
to != reflect.TypeOf((*Project)(nil)) {
return data, nil
}
p, ok := data.(*vagrant_server.Project)
if !ok {
return nil, fmt.Errorf("cannot deserialize project, wrong type (%T)", data)
}
return s.ProjectFromProto(p)
}
func (s *State) projectFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Project)(nil)) ||
to != reflect.TypeOf((*Project)(nil)) {
return data, nil
}
p, ok := data.(*vagrant_plugin_sdk.Ref_Project)
if !ok {
return nil, fmt.Errorf("cannot deserialize project ref, wrong type (%T)", data)
}
return s.ProjectFromProtoRef(p)
}
func basisToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Basis)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Basis)(nil)) {
return data, nil
}
b, ok := data.(*Basis)
if !ok {
return nil, fmt.Errorf("cannot serialize basis, wrong type (%T)", data)
}
return b.ToProto(), nil
}
func basisToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Basis)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Basis)(nil)) {
return data, nil
}
b, ok := data.(*Basis)
if !ok {
return nil, fmt.Errorf("cannot serialize basis ref, wrong type (%T)", data)
}
return b.ToProtoRef(), nil
}
func (s *State) basisFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Basis)(nil)) ||
to != reflect.TypeOf((*Basis)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_server.Basis)
if !ok {
return nil, fmt.Errorf("cannot deserialize basis, wrong type (%T)", data)
}
return s.BasisFromProto(b)
}
func (s *State) basisFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Basis)(nil)) ||
to != reflect.TypeOf((*Basis)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_plugin_sdk.Ref_Basis)
if !ok {
return nil, fmt.Errorf("cannot deserialize basis ref, wrong type (%T)", data)
}
return s.BasisFromProtoRef(b)
}
func targetToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Target)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Target)(nil)) {
return data, nil
}
t, ok := data.(*Target)
if !ok {
return nil, fmt.Errorf("cannot serialize target, wrong type (%T)", data)
}
return t.ToProto(), nil
}
func targetToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Target)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Target)(nil)) {
return data, nil
}
t, ok := data.(*Target)
if !ok {
return nil, fmt.Errorf("cannot serialize target ref, wrong type (%T)", data)
}
return t.ToProtoRef(), nil
}
func (s *State) targetFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Target)(nil)) ||
to != reflect.TypeOf((*Target)(nil)) {
return data, nil
}
t, ok := data.(*vagrant_server.Target)
if !ok {
return nil, fmt.Errorf("cannot deserialize target, wrong type (%T)", data)
}
return s.TargetFromProto(t)
}
func (s *State) targetFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Target)(nil)) ||
to != reflect.TypeOf((*Target)(nil)) {
return data, nil
}
t, ok := data.(*vagrant_plugin_sdk.Ref_Target)
if !ok {
return nil, fmt.Errorf("cannot deserialize target ref, wrong type (%T)", data)
}
return s.TargetFromProtoRef(t)
}
func vagrantfileToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Vagrantfile)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Vagrantfile)(nil)) {
return data, nil
}
v, ok := data.(*Vagrantfile)
if !ok {
return nil, fmt.Errorf("cannot serialize vagrantfile, wrong type (%T)", data)
}
return v.ToProto(), nil
}
func (s *State) vagrantfileFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Vagrantfile)(nil)) ||
to != reflect.TypeOf((*Vagrantfile)(nil)) {
return data, nil
}
v, ok := data.(*vagrant_server.Vagrantfile)
if !ok {
return nil, fmt.Errorf("cannot deserialize vagrantfile, wrong type (%T)", data)
}
return s.VagrantfileFromProto(v), nil
}
func runnerToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Runner)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Runner)(nil)) {
return data, nil
}
r, ok := data.(*Runner)
if !ok {
return nil, fmt.Errorf("cannot serialize runner, wrong type (%T)", data)
}
return r.ToProto(), nil
}
func (s *State) runnerFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Runner)(nil)) ||
to != reflect.TypeOf((*Runner)(nil)) {
return data, nil
}
r, ok := data.(*vagrant_server.Runner)
if !ok {
return nil, fmt.Errorf("cannot deserialize runner, wrong type (%T)", data)
}
return s.RunnerFromProto(r)
}
func protobufToProtoValueHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if to != reflect.TypeOf((*ProtoValue)(nil)) {
return data, nil
}
p, ok := data.(proto.Message)
if ok {
return &ProtoValue{Message: p}, nil
}
switch v := data.(type) {
case *vagrant_server.Job_Init:
return &ProtoValue{Message: v.Init}, nil
case *vagrant_server.Job_Command:
return &ProtoValue{Message: v.Command}, nil
case *vagrant_server.Job_Noop_:
return &ProtoValue{Message: v.Noop}, nil
}
return data, nil
}
func protobufToProtoRawHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if !from.Implements(reflect.TypeOf((*proto.Message)(nil)).Elem()) ||
to != reflect.TypeOf((*ProtoRaw)(nil)) {
return data, nil
}
p, ok := data.(proto.Message)
if !ok {
return nil, fmt.Errorf("cannot wrap into protovalue, wrong type (%T)", data)
}
return &ProtoRaw{Message: p}, nil
}
func boxToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Box)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Box)(nil)) {
return data, nil
}
b, ok := data.(*Box)
if !ok {
return nil, fmt.Errorf("cannot serialize box, wrong type (%T)", data)
}
return b.ToProto(), nil
}
func boxToProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Box)(nil)) ||
to != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Box)(nil)) {
return data, nil
}
b, ok := data.(*Box)
if !ok {
return nil, fmt.Errorf("cannot serialize box ref, wrong type (%T)", data)
}
return b.ToProtoRef(), nil
}
func (s *State) boxFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Box)(nil)) ||
to != reflect.TypeOf((*Box)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_server.Box)
if !ok {
return nil, fmt.Errorf("cannot deserialize box, wrong type (%T)", data)
}
return s.BoxFromProto(b)
}
func (s *State) boxFromProtoRefHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_plugin_sdk.Ref_Box)(nil)) ||
to != reflect.TypeOf((*Box)(nil)) {
return data, nil
}
b, ok := data.(*vagrant_plugin_sdk.Ref_Box)
if !ok {
return nil, fmt.Errorf("cannot deserialize box ref, wrong type (%T)", data)
}
return s.BoxFromProtoRef(b)
}
func timeToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*time.Time)(nil)) ||
to != reflect.TypeOf((*timestamppb.Timestamp)(nil)) {
return data, nil
}
t, ok := data.(*time.Time)
if !ok {
return nil, fmt.Errorf("cannot serialize time, wrong type (%T)", data)
}
return timestamppb.New(*t), nil
}
func timeFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*timestamppb.Timestamp)(nil)) ||
to != reflect.TypeOf((*time.Time)(nil)) {
return data, nil
}
t, ok := data.(*timestamppb.Timestamp)
if !ok {
return nil, fmt.Errorf("cannot deserialize time, wrong type (%T)", data)
}
at := t.AsTime()
return &at, nil
}
func protoValueToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*ProtoValue)(nil)) {
return data, nil
}
p, ok := data.(*ProtoValue)
if !ok {
return nil, fmt.Errorf("invalid proto value (%s -> %s)", from, to)
}
if p.Message == nil {
return nil, fmt.Errorf("proto value contents is nil (destination: %s)", to)
}
if reflect.ValueOf(p.Message).Type().AssignableTo(to) {
return p.Message, nil
}
switch v := p.Message.(type) {
// Start with Job oneof types
case *vagrant_server.Job_InitOp:
if reflect.TypeOf((*vagrant_server.Job_Init)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Init{Init: v}, nil
}
case *vagrant_server.Job_CommandOp:
if reflect.TypeOf((*vagrant_server.Job_Command)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Command{Command: v}, nil
}
case *vagrant_server.Job_Noop:
if reflect.TypeOf((*vagrant_server.Job_Noop_)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Noop_{Noop: v}, nil
}
}
return data, nil
}
func protoRawToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*ProtoRaw)(nil)) {
return data, nil
}
p, ok := data.(*ProtoRaw)
if !ok {
return nil, fmt.Errorf("invalid proto value (%s -> %s)", from, to)
}
if !reflect.ValueOf(p.Message).Type().AssignableTo(to) {
return data, nil
}
return p.Message, nil
}
func (s *State) scopeFromProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if (from != reflect.TypeOf((*vagrant_server.Job_Basis)(nil)) &&
from != reflect.TypeOf((*vagrant_server.Job_Project)(nil)) &&
from != reflect.TypeOf((*vagrant_server.Job_Target)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Basis)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Project)(nil)) &&
from != reflect.TypeOf((*vagrant_server.ConfigVar_Target)(nil))) ||
!to.Implements(reflect.TypeOf((*scope)(nil)).Elem()) {
return data, nil
}
var result scope
var err error
switch v := data.(type) {
case *vagrant_server.Job_Basis:
result, err = s.BasisFromProtoRef(v.Basis)
case *vagrant_server.ConfigVar_Basis:
result, err = s.BasisFromProtoRef(v.Basis)
case *vagrant_server.Job_Project:
result, err = s.ProjectFromProtoRef(v.Project)
case *vagrant_server.ConfigVar_Project:
result, err = s.ProjectFromProtoRef(v.Project)
case *vagrant_server.Job_Target:
result, err = s.TargetFromProtoRef(v.Target)
case *vagrant_server.ConfigVar_Target:
result, err = s.TargetFromProtoRef(v.Target)
default:
err = fmt.Errorf("invalid job scope type (%T)", data)
}
if err != nil {
return nil, err
}
return result, nil
}
func scopeToProtoHookFunc(
from reflect.Type,
to reflect.Type,
data interface{},
) (interface{}, error) {
if !from.Implements(reflect.TypeOf((*scope)(nil)).Elem()) {
return data, nil
}
switch v := data.(type) {
case *Basis:
if reflect.TypeOf((*vagrant_server.Job_Basis)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Basis{Basis: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Basis)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Basis{Basis: v.ToProtoRef()}, nil
}
case *Project:
if reflect.TypeOf((*vagrant_server.Job_Project)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Project{Project: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Project)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Project{Project: v.ToProtoRef()}, nil
}
case *Target:
if reflect.TypeOf((*vagrant_server.Job_Target)(nil)).AssignableTo(to) {
return &vagrant_server.Job_Target{Target: v.ToProtoRef()}, nil
}
if reflect.TypeOf((*vagrant_server.ConfigVar_Target)(nil)).AssignableTo(to) {
return &vagrant_server.ConfigVar_Target{Target: v.ToProtoRef()}, nil
}
}
return data, nil
}
func (s *State) componentFromProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*vagrant_server.Component)(nil)) ||
to != reflect.TypeOf((*Component)(nil)) {
return data, nil
}
c, ok := data.(*vagrant_server.Component)
if !ok {
return nil, fmt.Errorf("cannot deserialize component, wrong type (%T)", data)
}
return s.ComponentFromProto(c)
}
func componentToProtoHookFunc(
from, to reflect.Type,
data interface{},
) (interface{}, error) {
if from != reflect.TypeOf((*Component)(nil)) ||
to != reflect.TypeOf((*vagrant_server.Component)(nil)) {
return data, nil
}
c, ok := data.(*Component)
if !ok {
return nil, fmt.Errorf("cannot serialize component, wrong type (%T)", data)
}
return c.ToProto(), nil
}

View File

@ -0,0 +1,40 @@
package state
import (
"testing"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"github.com/stretchr/testify/require"
)
func TestSoftDecode(t *testing.T) {
t.Run("Decodes nothing when unset", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
tref := &vagrant_plugin_sdk.Target{}
var target Target
err := s.softDecode(tref, &target)
require.NoError(err)
require.Equal(target, Target{})
})
t.Run("Decodes project reference", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
pref := testProject(t, s)
tproto := &vagrant_server.Target{
Project: pref,
}
var target Target
err := s.softDecode(tproto, &target)
require.NoError(err)
require.NotNil(target.Project)
require.Equal(*target.Project.ResourceId, tproto.Project.ResourceId)
})
}

View File

@ -2,26 +2,32 @@ package state
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"time"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/logbuffer"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
)
func init() {
models = append(models, &InternalJob{})
dbIndexers = append(dbIndexers, (*State).jobIndexInit)
schemas = append(schemas, jobSchema)
}
var (
jobBucket = []byte("jobs")
jobWaitingTimeout = 2 * time.Minute
jobHeartbeatTimeout = 2 * time.Minute
)
@ -35,10 +41,148 @@ const (
maximumJobsInMem = 10000
)
func init() {
dbBuckets = append(dbBuckets, jobBucket)
dbIndexers = append(dbIndexers, (*State).jobIndexInit)
schemas = append(schemas, jobSchema)
type JobState uint8
const (
JOB_STATE_UNKNOWN JobState = JobState(vagrant_server.Job_UNKNOWN)
JOB_STATE_QUEUED = JobState(vagrant_server.Job_QUEUED)
JOB_STATE_WAITING = JobState(vagrant_server.Job_WAITING)
JOB_STATE_RUNNING = JobState(vagrant_server.Job_RUNNING)
JOB_STATE_ERROR = JobState(vagrant_server.Job_ERROR)
JOB_STATE_SUCCESS = JobState(vagrant_server.Job_SUCCESS)
)
type InternalJob struct {
gorm.Model
AssignTime *time.Time
AckTime *time.Time
AssignedRunnerID uint `mapstructure:"-"`
AssignedRunner *Runner
CancelTime *time.Time
CompleteTime *time.Time
DataSource *ProtoValue
DataSourceOverrides MetadataSet
Error *ProtoValue
ExpireTime *time.Time
Labels MetadataSet
Jid *string `gorm:"uniqueIndex" mapstructure:"Id"`
Operation *ProtoValue `mapstructure:"Operation"`
QueueTime *time.Time
Result *ProtoValue
Scope scope `gorm:"-:all"`
ScopeID uint `mapstructure:"-"`
ScopeType string `mapstructure:"-"`
State JobState
TargetRunner *ProtoValue
}
// Job should come with an ID assigned to it, but if it doesn't for
// some reason, we assign one now.
func (i *InternalJob) BeforeCreate(tx *gorm.DB) error {
if i.Jid == nil {
id, err := server.Id()
if err != nil {
return err
}
i.Jid = &id
}
return nil
}
// If the job has a scope assigned to it, persist it.
func (i *InternalJob) BeforeSave(tx *gorm.DB) (err error) {
if i.Scope == nil {
i.ScopeID = 0
i.ScopeType = ""
return nil
}
switch v := i.Scope.(type) {
case *Basis:
i.ScopeID = v.ID
i.ScopeType = "basis"
case *Project:
i.ScopeID = v.ID
i.ScopeType = "project"
case *Target:
i.ScopeID = v.ID
i.ScopeType = "target"
default:
return fmt.Errorf("unknown scope type (%T)", i.Scope)
}
return nil
}
// If the job has a scope, load it.
func (i *InternalJob) AfterFind(tx *gorm.DB) (err error) {
if i.ScopeID == 0 {
return nil
}
switch i.ScopeType {
case "basis":
var b Basis
result := tx.Preload(clause.Associations).
First(&b, &Basis{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &b
case "project":
var p Project
result := tx.Preload(clause.Associations).
First(&p, &Project{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &p
case "target":
var t Target
result := tx.Preload(clause.Associations).
First(&t, &Target{Model: gorm.Model{ID: i.ScopeID}})
if result.Error != nil {
return result.Error
}
i.Scope = &t
default:
return fmt.Errorf("unknown scope type (%s)", i.ScopeType)
}
return nil
}
// Convert job to a protobuf message
func (i *InternalJob) ToProto() *vagrant_server.Job {
if i == nil {
return nil
}
var j vagrant_server.Job
err := decode(i, &j)
if err != nil {
panic("failed to decode job: " + err.Error())
}
return &j
}
func (s *State) InternalJobFromProto(job *vagrant_server.Job) (*InternalJob, error) {
if job == nil {
return nil, ErrEmptyProtoArgument
}
if job.Id == "" {
return nil, gorm.ErrRecordNotFound
}
var j InternalJob
result := s.search().First(&j, &InternalJob{Jid: &job.Id})
if result.Error != nil {
return nil, result.Error
}
return &j, nil
}
func jobSchema() *memdb.TableSchema {
@ -115,9 +259,9 @@ type jobIndex struct {
// The basis/project/machine that this job is part of. This is used
// to determine if the job is blocked. See job_assigned.go for more details.
Basis *vagrant_plugin_sdk.Ref_Basis
Project *vagrant_plugin_sdk.Ref_Project
Target *vagrant_plugin_sdk.Ref_Target
Scope interface {
GetResourceId() string
}
// QueueTime is the time that the job was queued.
QueueTime time.Time
@ -171,14 +315,16 @@ func (s *State) JobCreate(jobpb *vagrant_server.Job) error {
txn := s.inmem.Txn(true)
defer txn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.jobCreate(dbTxn, txn, jobpb)
})
err := s.jobCreate(txn, jobpb)
if err == nil {
txn.Commit()
}
return err
if err != nil {
return lookupErrorToStatus("job", err)
}
return nil
}
// JobList returns the list of jobs.
@ -188,7 +334,7 @@ func (s *State) JobList() ([]*vagrant_server.Job, error) {
iter, err := memTxn.Get(jobTableName, jobIdIndexName+"_prefix", "")
if err != nil {
return nil, err
return nil, lookupErrorToStatus("job", err)
}
var result []*vagrant_server.Job
@ -199,11 +345,10 @@ func (s *State) JobList() ([]*vagrant_server.Job, error) {
}
idx := next.(*jobIndex)
var job *vagrant_server.Job
err = s.db.View(func(dbTxn *bolt.Tx) error {
job, err = s.jobById(dbTxn, idx.Id)
return err
})
job, err := s.jobById(idx.Id)
if err != nil {
return nil, lookupErrorToStatus("job", err)
}
result = append(result, job)
}
@ -220,7 +365,7 @@ func (s *State) JobById(id string, ws memdb.WatchSet) (*Job, error) {
watchCh, raw, err := memTxn.FirstWatch(jobTableName, jobIdIndexName, id)
if err != nil {
return nil, err
return nil, lookupErrorToStatus("job", err)
}
ws.Add(watchCh)
@ -235,20 +380,19 @@ func (s *State) JobById(id string, ws memdb.WatchSet) (*Job, error) {
if jobIdx.State == vagrant_server.Job_QUEUED {
blocked, err = s.jobIsBlocked(memTxn, jobIdx, ws)
if err != nil {
return nil, err
return nil, lookupErrorToStatus("job", err)
}
}
var job *vagrant_server.Job
err = s.db.View(func(dbTxn *bolt.Tx) error {
job, err = s.jobById(dbTxn, jobIdx.Id)
return err
})
job, err := s.jobById(jobIdx.Id)
if err != nil {
return nil, lookupErrorToStatus("job", err)
}
result := jobIdx.Job(job)
result.Blocked = blocked
return result, err
return result, nil
}
// JobAssignForRunner will wait for and assign a job to a specific runner.
@ -266,10 +410,13 @@ RETRY_ASSIGN:
defer txn.Abort()
// Turn our runner into a runner record so we can more efficiently assign
runnerRec := newRunnerRecord(r)
runnerRec, err := s.RunnerFromProto(r)
if err != nil {
return nil, fmt.Errorf("runner lookup failed: %w", err)
}
// candidateQuery finds candidate jobs to assign.
type candidateFunc func(*memdb.Txn, memdb.WatchSet, *runnerRecord) (*jobIndex, error)
type candidateFunc func(*memdb.Txn, memdb.WatchSet, *Runner) (*jobIndex, error)
candidateQuery := []candidateFunc{
s.jobCandidateById,
s.jobCandidateAny,
@ -409,7 +556,7 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
// Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil {
return nil, err
return nil, lookupErrorToStatus("job", err)
}
if raw == nil {
return nil, status.Errorf(codes.NotFound, "job not found: %s", id)
@ -443,7 +590,7 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
return nil
})
if err != nil {
return nil, err
return nil, lookupErrorToStatus("job", err)
}
// Cancel our timer
@ -467,13 +614,13 @@ func (s *State) JobAck(id string, ack bool) (*Job, error) {
// Insert to update
if err := txn.Insert(jobTableName, job); err != nil {
return nil, err
return nil, saveErrorToStatus("job", err)
}
// Update our assigned state if we nacked
if !ack {
if err := s.jobAssignedSet(txn, job, false); err != nil {
return nil, err
return nil, saveErrorToStatus("job", err)
}
}
@ -491,7 +638,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil {
return err
return lookupErrorToStatus("job", err)
}
if raw == nil {
return status.Errorf(codes.NotFound, "job not found: %s", id)
@ -500,7 +647,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Update our assigned state
if err := s.jobAssignedSet(txn, job, false); err != nil {
return err
return saveErrorToStatus("job", err)
}
// If the job is not in the assigned state, then this is an error.
@ -528,7 +675,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
return nil
})
if err != nil {
return err
return saveErrorToStatus("job", err)
}
// End the job
@ -536,7 +683,7 @@ func (s *State) JobComplete(id string, result *vagrant_server.Job_Result, cerr e
// Insert to update
if err := txn.Insert(jobTableName, job); err != nil {
return err
return saveErrorToStatus("job", err)
}
txn.Commit()
@ -553,7 +700,7 @@ func (s *State) JobCancel(id string, force bool) error {
// Get the job
raw, err := txn.First(jobTableName, jobIdIndexName, id)
if err != nil {
return err
return lookupErrorToStatus("job", err)
}
if raw == nil {
return status.Errorf(codes.NotFound, "job not found: %s", id)
@ -561,7 +708,7 @@ func (s *State) JobCancel(id string, force bool) error {
job := raw.(*jobIndex)
if err := s.jobCancel(txn, job, force); err != nil {
return err
return saveErrorToStatus("job", err)
}
txn.Commit()
@ -717,11 +864,8 @@ func (s *State) JobExpire(id string) error {
// deregister between this returning true and queueing, the job may still
// sit in a queue indefinitely.
func (s *State) JobIsAssignable(ctx context.Context, jobpb *vagrant_server.Job) (bool, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
// If we have no runners, we cannot be assigned
empty, err := s.runnerEmpty(memTxn)
empty, err := s.runnerEmpty()
if err != nil {
return false, err
}
@ -730,74 +874,60 @@ func (s *State) JobIsAssignable(ctx context.Context, jobpb *vagrant_server.Job)
}
// If we have a special targeting constraint, that has to be met
var iter memdb.ResultIterator
var targetCheck func(*vagrant_server.Runner) (bool, error)
tx := s.db.Model(&Runner{})
switch v := jobpb.TargetRunner.Target.(type) {
case *vagrant_server.Ref_Runner_Any:
// We need a special target check that disallows by ID only
targetCheck = func(r *vagrant_server.Runner) (bool, error) {
return !r.ByIdOnly, nil
}
iter, err = memTxn.LowerBound(runnerTableName, runnerIdIndexName, "")
tx = tx.Where("by_id_only = ?", false)
case *vagrant_server.Ref_Runner_Id:
iter, err = memTxn.Get(runnerTableName, runnerIdIndexName, v.Id.Id)
tx = tx.Where("rid = ?", v.Id.Id)
default:
return false, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner.Target)
}
if err != nil {
return false, err
var c int64
result := tx.Count(&c)
if result.Error != nil {
return false, result.Error
}
for {
raw := iter.Next()
if raw == nil {
// We're out of candidates and we found none.
return false, nil
}
runner := raw.(*runnerRecord)
// Check our target-specific check
if targetCheck != nil {
check, err := targetCheck(runner.Runner)
if err != nil {
return false, err
}
if !check {
continue
}
}
// This works!
return true, nil
}
return c > 0, result.Error
}
// jobIndexInit initializes the config index from persisted data.
func (s *State) jobIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(jobBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Job
if err := proto.Unmarshal(v, &value); err != nil {
return err
func (s *State) jobIndexInit(memTxn *memdb.Txn) error {
var jobs []InternalJob
// Get all jobs which are not completed
result := s.search().
Where(&InternalJob{State: JOB_STATE_UNKNOWN}).
Or(&InternalJob{State: JOB_STATE_QUEUED}).
Or(&InternalJob{State: JOB_STATE_WAITING}).
Or(&InternalJob{State: JOB_STATE_RUNNING}).
Find(&jobs)
if result.Error != nil {
return result.Error
}
idx, err := s.jobIndexSet(memTxn, k, &value)
// Load all incomplete jobs into memory
for _, j := range jobs {
job := j.ToProto()
if j.Jid == nil {
continue
}
idx, err := s.jobIndexSet(memTxn, []byte(*j.Jid), job)
if err != nil {
return err
}
// If the job was running or waiting, set it as assigned.
if value.State == vagrant_server.Job_RUNNING || value.State == vagrant_server.Job_WAITING {
if err := s.jobAssignedSet(memTxn, idx, true); err != nil {
if j.State == JOB_STATE_WAITING || j.State == JOB_STATE_RUNNING {
if err = s.jobAssignedSet(memTxn, idx, true); err != nil {
return err
}
}
}
return nil
})
}
// jobIndexSet writes an index record for a single job.
@ -805,14 +935,20 @@ func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job
rec := &jobIndex{
Id: jobpb.Id,
State: jobpb.State,
Basis: jobpb.Basis,
Project: jobpb.Project,
Target: jobpb.Target,
OpType: reflect.TypeOf(jobpb.Operation),
}
switch v := jobpb.Scope.(type) {
case *vagrant_server.Job_Basis:
rec.Scope = v.Basis
case *vagrant_server.Job_Project:
rec.Scope = v.Project
case *vagrant_server.Job_Target:
rec.Scope = v.Target
}
// Target
if jobpb.TargetRunner == nil {
if jobpb.TargetRunner == nil || jobpb.TargetRunner.Target == nil {
return nil, fmt.Errorf("job target runner must be set")
}
switch v := jobpb.TargetRunner.Target.(type) {
@ -823,7 +959,7 @@ func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job
rec.TargetRunnerId = v.Id.Id
default:
return nil, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner.Target)
return nil, fmt.Errorf("unknown runner target value: %#v", jobpb.TargetRunner)
}
// Timestamps
@ -881,21 +1017,36 @@ func (s *State) jobIndexSet(txn *memdb.Txn, id []byte, jobpb *vagrant_server.Job
return rec, txn.Insert(jobTableName, rec)
}
func (s *State) jobCreate(dbTxn *bolt.Tx, memTxn *memdb.Txn, jobpb *vagrant_server.Job) error {
func (s *State) jobCreate(memTxn *memdb.Txn, jobpb *vagrant_server.Job) error {
// Setup our initial job state
var err error
jobpb.State = vagrant_server.Job_QUEUED
jobpb.QueueTime = timestamppb.New(time.Now())
id := []byte(jobpb.Id)
// Insert into bolt
if err := dbPut(dbTxn.Bucket(jobBucket), id, jobpb); err != nil {
// Convert the job proto into a record
job, err := s.InternalJobFromProto(jobpb)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// Insert into the DB
_, err = s.jobIndexSet(memTxn, id, jobpb)
if err != nil {
job = &InternalJob{}
}
if err = s.softDecode(jobpb, job); err != nil {
return err
}
// Save the record into the db
result := s.db.Create(job)
if result.Error != nil {
return result.Error
}
id := []byte(*job.Jid)
// Insert into the in memory db
_, err = s.jobIndexSet(memTxn, id, job.ToProto())
s.pruneMu.Lock()
defer s.pruneMu.Unlock()
@ -904,39 +1055,52 @@ func (s *State) jobCreate(dbTxn *bolt.Tx, memTxn *memdb.Txn, jobpb *vagrant_serv
return err
}
func (s *State) jobById(dbTxn *bolt.Tx, id string) (*vagrant_server.Job, error) {
var result vagrant_server.Job
b := dbTxn.Bucket(jobBucket)
return &result, dbGet(b, []byte(id), &result)
func (s *State) jobById(sid string) (*vagrant_server.Job, error) {
job, err := s.InternalJobFromProto(&vagrant_server.Job{Id: sid})
if err != nil {
return nil, err
}
return job.ToProto(), nil
}
func (s *State) jobReadAndUpdate(id string, f func(*vagrant_server.Job) error) (*vagrant_server.Job, error) {
var result *vagrant_server.Job
var err error
return result, s.db.Update(func(dbTxn *bolt.Tx) error {
result, err = s.jobById(dbTxn, id)
j, err := s.jobById(id)
if err != nil {
return err
return nil, err
}
// Modify
if err := f(result); err != nil {
return err
if err := f(j); err != nil {
return nil, err
}
// Commit
return dbPut(dbTxn.Bucket(jobBucket), []byte(id), result)
})
ij, err := s.InternalJobFromProto(j)
if err != nil {
return nil, err
}
if err := s.softDecode(j, ij); err != nil {
return nil, err
}
result := s.db.Save(ij)
if result.Error != nil {
return nil, result.Error
}
return ij.ToProto(), nil
}
// jobCandidateById returns the most promising candidate job to assign
// that is targeting a specific runner by ID.
func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runnerRecord) (*jobIndex, error) {
func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *Runner) (*jobIndex, error) {
iter, err := memTxn.LowerBound(
jobTableName,
jobTargetIdIndexName,
vagrant_server.Job_QUEUED,
r.Id,
*r.Rid,
time.Unix(0, 0),
)
if err != nil {
@ -950,7 +1114,7 @@ func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runner
}
job := raw.(*jobIndex)
if job.State != vagrant_server.Job_QUEUED || job.TargetRunnerId != r.Id {
if job.State != vagrant_server.Job_QUEUED || job.TargetRunnerId != *r.Rid {
continue
}
@ -968,7 +1132,7 @@ func (s *State) jobCandidateById(memTxn *memdb.Txn, ws memdb.WatchSet, r *runner
}
// jobCandidateAny returns the first candidate job that targets any runner.
func (s *State) jobCandidateAny(memTxn *memdb.Txn, ws memdb.WatchSet, r *runnerRecord) (*jobIndex, error) {
func (s *State) jobCandidateAny(memTxn *memdb.Txn, ws memdb.WatchSet, r *Runner) (*jobIndex, error) {
iter, err := memTxn.LowerBound(
jobTableName,
jobQueueTimeIndexName,
@ -1020,36 +1184,20 @@ func (s *State) jobsPruneOld(memTxn *memdb.Txn, max int) (int, error) {
}
func (s *State) JobsDBPruneOld(max int) (int, error) {
cnt := dbCount(s.db, jobTableName)
toDelete := cnt - max
var deleted int
// Prune jobs from boltDB
s.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(jobTableName))
cur := bucket.Cursor()
key, _ := cur.First()
for {
if key == nil {
break
var jobs []InternalJob
result := s.db.Select("id").Order("queue_time asc").Offset(max).Find(&jobs)
if result.Error != nil {
return 0, result.Error
}
// otherwise, prune this job! Once we've pruned enough jobs to get back
// to the maximum, we stop pruning.
toDelete--
err := bucket.Delete(key)
if err != nil {
return err
deleted := len(jobs)
if deleted < 1 {
return deleted, nil
}
result = s.db.Unscoped().Delete(jobs)
if result.Error != nil {
return 0, result.Error
}
deleted++
if toDelete <= 0 {
break
}
key, _ = cur.Next()
}
return nil
})
return deleted, nil
}

View File

@ -59,9 +59,7 @@ func jobAssignedSchema() *memdb.TableSchema {
}
type jobAssignedIndex struct {
Basis string
Project string
Machine string
ResourceId string
}
// jobIsBlocked will return true if the given job is currently blocked because
@ -80,7 +78,7 @@ func (s *State) jobIsBlocked(memTxn *memdb.Txn, idx *jobIndex, ws memdb.WatchSet
watchCh, value, err := memTxn.FirstWatch(
jobAssignedTableName,
jobAssignedIdIndexName,
s.jobAssignedIdxArgs(idx)...,
idx.Scope.GetResourceId(),
)
if err != nil {
return false, err
@ -100,11 +98,8 @@ func (s *State) jobAssignedSet(memTxn *memdb.Txn, idx *jobIndex, assigned bool)
return nil
}
args := s.jobAssignedIdxArgs(idx)
rec := &jobAssignedIndex{
Basis: args[0].(string),
Project: args[1].(string),
Machine: args[2].(string),
ResourceId: idx.Scope.GetResourceId(),
}
if assigned {
@ -113,16 +108,3 @@ func (s *State) jobAssignedSet(memTxn *memdb.Txn, idx *jobIndex, assigned bool)
return memTxn.Delete(jobAssignedTableName, rec)
}
func (s *State) jobAssignedIdxArgs(idx *jobIndex) []interface{} {
if idx.Target != nil {
return []interface{}{
idx.Target.Project.Basis.ResourceId, idx.Target.Project.ResourceId, idx.Target.ResourceId,
}
} else if idx.Project != nil {
return []interface{}{
idx.Project.Basis.ResourceId, idx.Project.ResourceId, "",
}
}
return []interface{}{idx.Basis.ResourceId, "", ""}
}

View File

@ -11,7 +11,7 @@ import (
"google.golang.org/grpc/status"
"github.com/hashicorp/go-memdb"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
// "github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
)
@ -23,9 +23,15 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -53,11 +59,14 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "project1",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
@ -90,10 +99,10 @@ func TestJobAssign(t *testing.T) {
}
// Insert another job
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B",
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: "project2",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
@ -268,13 +277,22 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create two builds slightly apart
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get build A then B
@ -304,9 +322,16 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
// Create a build by ID
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{
@ -316,11 +341,11 @@ func TestJobAssign(t *testing.T) {
},
})))
time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B",
})))
time.Sleep(1 * time.Millisecond)
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "C",
})))
@ -354,9 +379,16 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build by ID
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{
@ -393,12 +425,17 @@ func TestJobAssign(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
r := &vagrant_server.Runner{Id: "R_A", ByIdOnly: true}
testRunner(t, s, r)
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Should block because none direct assign
@ -410,8 +447,11 @@ func TestJobAssign(t *testing.T) {
require.Equal(ctx.Err(), err)
// Create a target
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{
@ -436,9 +476,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -466,9 +512,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -501,9 +553,15 @@ func TestJobAck(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -533,9 +591,15 @@ func TestJobComplete(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -550,7 +614,7 @@ func TestJobComplete(t *testing.T) {
// Complete it
require.NoError(s.JobComplete(job.Id, &vagrant_server.Job_Result{
Run: &vagrant_server.Job_RunResult{},
Run: &vagrant_server.Job_CommandResult{},
}, nil))
// Verify it is changed
@ -568,9 +632,15 @@ func TestJobComplete(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -606,9 +676,14 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create a build
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{
result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
}))
require.NoError(err)
require.False(result)
@ -620,13 +695,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t)
defer s.Close()
// Register a runner
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, nil)))
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{
result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{},
@ -643,15 +720,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t)
defer s.Close()
// Register a runner
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, &vagrant_server.Runner{
ByIdOnly: true,
})))
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A", ByIdOnly: true})
// Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{
result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{},
@ -668,13 +745,15 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t)
defer s.Close()
// Register a runner
require.NoError(s.RunnerCreate(serverptypes.TestRunner(t, nil)))
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_B"})
// Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{
result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{
@ -693,18 +772,19 @@ func TestJobIsAssignable(t *testing.T) {
s := TestState(t)
defer s.Close()
// Register a runner
runner := serverptypes.TestRunner(t, nil)
require.NoError(s.RunnerCreate(runner))
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Should be assignable
result, err := s.JobIsAssignable(ctx, serverptypes.TestJobNew(t, &vagrant_server.Job{
result, err := s.JobIsAssignable(ctx, testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Id{
Id: &vagrant_server.Ref_RunnerId{
Id: runner.Id,
Id: "R_A",
},
},
},
@ -720,10 +800,14 @@ func TestJobCancel(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Cancel it
@ -742,10 +826,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -770,10 +859,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -798,11 +892,16 @@ func TestJobCancel(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Operation: &vagrant_server.Job_Run{},
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
Operation: &vagrant_server.Job_Command{},
})))
// Assign it, we should get this build
@ -822,9 +921,12 @@ func TestJobCancel(t *testing.T) {
require.NotEmpty(job.CancelTime)
// Create a another job
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "B",
Operation: &vagrant_server.Job_Run{},
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
Operation: &vagrant_server.Job_Command{},
})))
ws := memdb.NewWatchSet()
@ -843,10 +945,15 @@ func TestJobCancel(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -880,6 +987,8 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Set a short timeout
old := jobHeartbeatTimeout
@ -887,8 +996,11 @@ func TestJobHeartbeat(t *testing.T) {
jobHeartbeatTimeout = 5 * time.Millisecond
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -902,6 +1014,8 @@ func TestJobHeartbeat(t *testing.T) {
_, err = s.JobAck(job.Id, true)
require.NoError(err)
time.Sleep(1 * time.Second)
// Should time out
require.Eventually(func() bool {
// Verify it is canceled
@ -921,10 +1035,15 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -985,9 +1104,15 @@ func TestJobHeartbeat(t *testing.T) {
s := TestState(t)
defer s.Close()
projRef := testProject(t, s)
testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
Id: "A",
Scope: &vagrant_server.Job_Project{
Project: projRef,
},
})))
// Assign it, we should get this build
@ -1026,7 +1151,7 @@ func TestJobHeartbeat(t *testing.T) {
}()
// Sleep for a bit
time.Sleep(1 * time.Second)
time.Sleep(10 * time.Millisecond)
// Verify it is running
job, err = s.JobById("A", nil)
@ -1036,6 +1161,10 @@ func TestJobHeartbeat(t *testing.T) {
// Stop heartbeating
cancel()
// Pause before check. We encounter the database being
// scrubbed otherwise (TODO: fixme)
time.Sleep(1 * time.Second)
// Should time out
require.Eventually(func() bool {
// Verify it is canceled
@ -1045,72 +1174,77 @@ func TestJobHeartbeat(t *testing.T) {
}, 1*time.Second, 10*time.Millisecond)
})
t.Run("times out if running state loaded on restart", func(t *testing.T) {
require := require.New(t)
// t.Run("times out if running state loaded on restart", func(t *testing.T) {
// require := require.New(t)
// Set a short timeout
old := jobHeartbeatTimeout
defer func() { jobHeartbeatTimeout = old }()
jobHeartbeatTimeout = 250 * time.Millisecond
// // Set a short timeout
// old := jobHeartbeatTimeout
// defer func() { jobHeartbeatTimeout = old }()
// jobHeartbeatTimeout = 250 * time.Millisecond
s := TestState(t)
defer s.Close()
// s := TestState(t)
// defer s.Close()
// projRef := testProject(t, s)
// testRunner(t, s, &vagrant_server.Runner{Id: "R_A"})
// Create a build
require.NoError(s.JobCreate(serverptypes.TestJobNew(t, &vagrant_server.Job{
Id: "A",
})))
// // Create a build
// require.NoError(s.JobCreate(testJob(t, &vagrant_server.Job{
// Id: "A",
// Scope: &vagrant_server.Job_Project{
// Project: projRef,
// },
// })))
// Assign it, we should get this build
job, err := s.JobAssignForRunner(context.Background(), &vagrant_server.Runner{Id: "R_A"})
require.NoError(err)
require.NotNil(job)
require.Equal("A", job.Id)
require.Equal(vagrant_server.Job_WAITING, job.State)
// // Assign it, we should get this build
// job, err := s.JobAssignForRunner(context.Background(), &vagrant_server.Runner{Id: "R_A"})
// require.NoError(err)
// require.NotNil(job)
// require.Equal("A", job.Id)
// require.Equal(vagrant_server.Job_WAITING, job.State)
// Ack it
_, err = s.JobAck(job.Id, true)
require.NoError(err)
// // Ack it
// _, err = s.JobAck(job.Id, true)
// require.NoError(err)
// Start heartbeating
ctx, cancel := context.WithCancel(context.Background())
doneCh := make(chan struct{})
defer func() {
cancel()
<-doneCh
}()
go func(s *State) {
defer close(doneCh)
// // Start heartbeating
// ctx, cancel := context.WithCancel(context.Background())
// doneCh := make(chan struct{})
// defer func() {
// cancel()
// <-doneCh
// }()
// go func(s *State) {
// defer close(doneCh)
tick := time.NewTicker(20 * time.Millisecond)
defer tick.Stop()
// tick := time.NewTicker(20 * time.Millisecond)
// defer tick.Stop()
for {
select {
case <-tick.C:
s.JobHeartbeat(job.Id)
// for {
// select {
// case <-tick.C:
// s.JobHeartbeat(job.Id)
case <-ctx.Done():
return
}
}
}(s)
// case <-ctx.Done():
// return
// }
// }
// }(s)
// Reinit the state as if we crashed
s = TestStateReinit(t, s)
defer s.Close()
// s = TestStateReinit(t, s)
// defer s.Close()
// Verify it exists
job, err = s.JobById("A", nil)
require.NoError(err)
require.Equal(vagrant_server.Job_RUNNING, job.Job.State)
// // Verify it exists
// job, err = s.JobById("A", nil)
// require.NoError(err)
// require.Equal(vagrant_server.Job_RUNNING, job.Job.State)
// Should time out
require.Eventually(func() bool {
// Verify it is canceled
job, err = s.JobById("A", nil)
require.NoError(err)
return job.Job.State == vagrant_server.Job_ERROR
}, 2*time.Second, 10*time.Millisecond)
})
// // Should time out
// require.Eventually(func() bool {
// // Verify it is canceled
// job, err = s.JobById("A", nil)
// require.NoError(err)
// return job.Job.State == vagrant_server.Job_ERROR
// }, 2*time.Second, 10*time.Millisecond)
// })
}

View File

@ -0,0 +1,75 @@
package state
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// MetadataSet is a simple map with a string key type
// and string value type. It is stored within the database
// as a JSON type so it can be queried.
type MetadataSet map[string]string
// User consumable data type name
func (m MetadataSet) GormDataType() string {
return datatypes.JSON{}.GormDataType()
}
// Driver consumable data type name
func (m MetadataSet) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return datatypes.JSON{}.GormDBDataType(db, field)
}
// Unmarshals the store value back to original type
func (m MetadataSet) Scan(value interface{}) error {
v, ok := value.([]byte)
if !ok {
return fmt.Errorf("Failed to unmarshal JSON value: %v", value)
}
j := datatypes.JSON{}
err := j.UnmarshalJSON(v)
if err != nil {
return err
}
result := MetadataSet{}
err = json.Unmarshal(j, &result)
if err != nil {
return err
}
m = result
return nil
}
// Marshal the value for storage in the database
func (m MetadataSet) Value() (driver.Value, error) {
if len(m) < 1 {
return nil, nil
}
v, err := json.Marshal(m)
if err != nil {
return nil, err
}
return string(v), nil
}
// Convert the MetadataSet into a protobuf message
func (m MetadataSet) ToProto() *vagrant_plugin_sdk.Args_MetadataSet {
return &vagrant_plugin_sdk.Args_MetadataSet{
Metadata: map[string]string(m),
}
}
var (
_ sql.Scanner = (*MetadataSet)(nil)
_ driver.Valuer = (*MetadataSet)(nil)
_ schema.GormDataTypeInterface = (*ProtoValue)(nil)
_ migrator.GormDataTypeInterface = (*ProtoValue)(nil)
)

View File

@ -1,396 +1,342 @@
package state
import (
"strings"
"google.golang.org/protobuf/proto"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"errors"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
"gorm.io/gorm"
)
var projectBucket = []byte("project")
func init() {
dbBuckets = append(dbBuckets, projectBucket)
dbIndexers = append(dbIndexers, (*State).projectIndexInit)
schemas = append(schemas, projectIndexSchema)
models = append(models, &Project{})
}
// ProjectPut creates or updates the given project.
func (s *State) ProjectPut(p *vagrant_server.Project) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
type Project struct {
gorm.Model
err := s.db.Update(func(dbTxn *bolt.Tx) (err error) {
return s.projectPut(dbTxn, memTxn, p)
})
Basis *Basis
BasisID *uint `gorm:"uniqueIndex:idx_bname;not null" mapstructure:"-"`
Vagrantfile *Vagrantfile `mapstructure:"-"`
VagrantfileID *uint `mapstructure:"-"`
DataSource *ProtoValue
Jobs []*InternalJob `gorm:"polymorphic:Scope;"`
Metadata MetadataSet
Name *string `gorm:"uniqueIndex:idx_bname;not null"`
Path *string `gorm:"uniqueIndex;not null"`
RemoteEnabled bool
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
Targets []*Target
}
if err == nil {
memTxn.Commit()
}
func (p *Project) scope() interface{} {
return p
}
// Set a public ID on the project before creating
func (p *Project) BeforeSave(tx *gorm.DB) error {
if p.ResourceId == nil {
if err := p.setId(); err != nil {
return err
}
func (s *State) ProjectFind(p *vagrant_server.Project) (*vagrant_server.Project, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Project
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.projectFind(dbTxn, memTxn, p)
}
}
if err := p.Validate(tx); err != nil {
return err
})
}
return result, err
return nil
}
// ProjectGet gets a project by reference.
func (s *State) ProjectGet(ref *vagrant_plugin_sdk.Ref_Project) (*vagrant_server.Project, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
func (p *Project) Validate(tx *gorm.DB) error {
err := validation.ValidateStruct(p,
// validation.Field(&p.Basis, validation.Required),
validation.Field(&p.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{Name: p.Name, BasisID: p.BasisID}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
validation.Field(&p.Path,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{Path: p.Path, BasisID: p.BasisID}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
validation.Field(&p.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Project{}).
Where(&Project{ResourceId: p.ResourceId}).
Not(&Project{Model: gorm.Model{ID: p.ID}}),
),
),
),
)
var result *vagrant_server.Project
err := s.db.View(func(dbTxn *bolt.Tx) (err error) {
result, err = s.projectGet(dbTxn, memTxn, ref)
return err
})
return result, err
}
// ProjectDelete deletes a project by reference. This is a complete data
// delete. This will delete all operations associated with this project
// as well.
func (s *State) ProjectDelete(ref *vagrant_plugin_sdk.Ref_Project) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
// Now remove the project
return s.projectDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
}
return err
}
// ProjectList returns the list of projects.
func (s *State) ProjectList() ([]*vagrant_plugin_sdk.Ref_Project, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
return s.projectList(memTxn)
}
func (s *State) projectFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
p *vagrant_server.Project,
) (*vagrant_server.Project, error) {
var match *projectIndexRecord
// Start with the resource id first
if p.ResourceId != "" {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexIdIndexName,
p.ResourceId,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
// Try the name next
if p.Name != "" && match == nil {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexNameIndexName,
p.Name,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
// And finally the path
if p.Path != "" && match == nil {
if raw, err := memTxn.First(
projectIndexTableName,
projectIndexPathIndexName,
p.Path,
); raw != nil && err == nil {
match = raw.(*projectIndexRecord)
}
}
if match == nil {
return nil, status.Errorf(codes.NotFound, "record not found for Project")
}
return s.projectGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Project{
ResourceId: match.Id,
})
}
func (s *State) projectPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Project,
) (err error) {
s.log.Trace("storing project", "project", value, "basis", value.Basis)
// Grab the stored project if it's available
existProject, err := s.projectFind(dbTxn, memTxn, value)
if err != nil {
// ensure value is nil to identify non-existence
existProject = nil
return err
}
// Grab the basis associated to this project so it can be attached
b, err := s.basisGet(dbTxn, memTxn, value.Basis)
return nil
}
func (p *Project) setId() error {
id, err := server.Id()
if err != nil {
s.log.Error("failed to locate basis for project", "project", value,
"basis", value.Basis, "error", err)
return
return err
}
p.ResourceId = &id
// set a resource id if none set
if value.ResourceId == "" {
s.log.Trace("project has no resource id, assuming new project",
"project", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for project", "project", value,
"error", err)
return
}
}
return nil
}
s.log.Trace("storing project to db", "project", value)
id := s.projectId(value)
// Get the global bucket and write the value to it.
bkt := dbTxn.Bucket(projectBucket)
if err = dbPut(bkt, id, value); err != nil {
s.log.Error("failed to store project in db", "project", value, "error", err)
return
}
s.log.Trace("indexing project", "project", value)
// Create our index value and write that.
if err = s.projectIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index project", "project", value, "error", err)
return
}
s.log.Trace("adding project to basis", "project", value, "basis", b)
nb := &serverptypes.Basis{Basis: b}
if nb.AddProject(value) {
s.log.Trace("project added to basis, updating basis", "basis", b)
if err = s.basisPut(dbTxn, memTxn, b); err != nil {
s.log.Error("failed to update basis", "basis", b, "error", err)
return
}
} else {
s.log.Trace("project already exists in basis", "project", value, "basis", b)
}
// Check if the project basis was changed
if existProject != nil && existProject.Basis.ResourceId != b.ResourceId {
s.log.Trace("project basis has changed, updating old basis", "project", value,
"old-basis", existProject.Basis, "new-basis", value.Basis)
ob, err := s.basisGet(dbTxn, memTxn, existProject.Basis)
if err != nil {
s.log.Warn("failed to locate old basis, ignoring", "project", value, "old-basis",
existProject.Basis, "error", err)
// Convert project to reference protobuf message
func (p *Project) ToProtoRef() *vagrant_plugin_sdk.Ref_Project {
if p == nil {
return nil
}
bt := &serverptypes.Basis{Basis: ob}
if bt.DeleteProject(value) {
s.log.Trace("project deleted from old basis, updating basis", "project", value,
"old-basis", ob)
if err := s.basisPut(dbTxn, memTxn, ob); err != nil {
s.log.Error("failed to updated old basis for project removal", "project", value,
"old-basis", ob, "error", err)
}
}
ref := vagrant_plugin_sdk.Ref_Project{}
err := decode(p, &ref)
if err != nil {
panic("failed to decode project to ref: " + err.Error())
}
return
return &ref
}
func (s *State) projectGet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
// Convert project to protobuf message
func (p *Project) ToProto() *vagrant_server.Project {
if p == nil {
return nil
}
var project vagrant_server.Project
err := decode(p, &project)
if err != nil {
panic("failed to decode project: " + err.Error())
}
// Manually include the vagrantfile since we force it to be ignored
if p.Vagrantfile != nil {
project.Configuration = p.Vagrantfile.ToProto()
}
return &project
}
// Load a Project from reference protobuf message.
func (s *State) ProjectFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Project,
) (*vagrant_server.Project, error) {
var result vagrant_server.Project
b := dbTxn.Bucket(projectBucket)
return &result, dbGet(b, s.projectIdByRef(ref), &result)
) (*Project, error) {
if ref == nil {
return nil, ErrEmptyProtoArgument
}
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var project Project
result := s.search().First(&project,
&Project{ResourceId: &ref.ResourceId})
if result.Error != nil {
return nil, result.Error
}
return &project, nil
}
func (s *State) projectList(
memTxn *memdb.Txn,
) ([]*vagrant_plugin_sdk.Ref_Project, error) {
iter, err := memTxn.Get(projectIndexTableName, projectIndexIdIndexName+"_prefix", "")
func (s *State) ProjectFromProtoRefFuzzy(
ref *vagrant_plugin_sdk.Ref_Project,
) (*Project, error) {
project, err := s.ProjectFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if ref.Basis == nil {
return nil, ErrMissingProtoParent
}
if ref.Name == "" && ref.Path == "" {
return nil, gorm.ErrRecordNotFound
}
project = &Project{}
query := &Project{}
if ref.Name != "" {
query.Name = &ref.Name
}
if ref.Path != "" {
query.Path = &ref.Path
}
result := s.search().
Joins("Basis", &Basis{ResourceId: &ref.Basis.ResourceId}).
Where(query).
First(project)
if result.Error != nil {
return nil, result.Error
}
return project, nil
}
// Load a Project from protobuf message.
func (s *State) ProjectFromProto(
p *vagrant_server.Project,
) (*Project, error) {
if p == nil {
return nil, ErrEmptyProtoArgument
}
project, err := s.ProjectFromProtoRef(
&vagrant_plugin_sdk.Ref_Project{
ResourceId: p.ResourceId,
},
)
if err != nil {
return nil, err
}
var result []*vagrant_plugin_sdk.Ref_Project
for {
next := iter.Next()
if next == nil {
break
}
idx := next.(*projectIndexRecord)
result = append(result, &vagrant_plugin_sdk.Ref_Project{
ResourceId: idx.Id,
Name: idx.Name,
})
}
return result, nil
return project, nil
}
func (s *State) projectDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Project,
) (err error) {
p, err := s.projectGet(dbTxn, memTxn, ref)
func (s *State) ProjectFromProtoFuzzy(
p *vagrant_server.Project,
) (*Project, error) {
if p == nil {
return nil, ErrEmptyProtoArgument
}
project, err := s.ProjectFromProtoRefFuzzy(
&vagrant_plugin_sdk.Ref_Project{
ResourceId: p.ResourceId,
Basis: p.Basis,
Name: p.Name,
Path: p.Path,
},
)
if err != nil {
return
return nil, err
}
// Start with scrubbing all the machines
for _, m := range p.Targets {
if err = s.targetDelete(dbTxn, memTxn, m); err != nil {
return
}
}
return project, nil
}
// Grab the basis and remove the project
b, err := s.basisGet(dbTxn, memTxn, ref.Basis)
// Get a project record using a reference protobuf message
func (s *State) ProjectGet(
p *vagrant_plugin_sdk.Ref_Project,
) (*vagrant_server.Project, error) {
project, err := s.ProjectFromProtoRef(p)
if err != nil {
return
}
bp := &serverptypes.Basis{Basis: b}
if bp.DeleteProjectRef(ref) {
err = s.basisPut(dbTxn, memTxn, b)
return nil, lookupErrorToStatus("project", err)
}
// Delete from bolt
if err := dbTxn.Bucket(projectBucket).Delete(s.projectId(p)); err != nil {
return err
}
// Delete from memdb
if err := memTxn.Delete(projectIndexTableName, s.newProjectIndexRecord(p)); err != nil {
return err
}
return
return project.ToProto(), nil
}
// projectIndexSet writes an index record for a single project.
func (s *State) projectIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Project) error {
return txn.Insert(projectIndexTableName, s.newProjectIndexRecord(value))
// Store a Project
func (s *State) ProjectPut(
p *vagrant_server.Project,
) (*vagrant_server.Project, error) {
project, err := s.ProjectFromProto(p)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("project", err)
}
// Make sure we don't have a nil value
if err != nil {
project = &Project{}
}
err = s.softDecode(p, project)
if err != nil {
return nil, saveErrorToStatus("project", err)
}
// If a configuration came over the wire, either create one to attach
// to the project or update the existing one
if p.Configuration != nil {
if project.Vagrantfile != nil {
project.Vagrantfile.UpdateFromProto(p.Configuration)
} else {
project.Vagrantfile = s.VagrantfileFromProto(p.Configuration)
}
}
result := s.db.Save(project)
if result.Error != nil {
return nil, saveErrorToStatus("project", result.Error)
}
return project.ToProto(), nil
}
// projectIndexInit initializes the project index from persisted data.
func (s *State) projectIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(projectBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Project
if err := proto.Unmarshal(v, &value); err != nil {
return err
// List all project records
func (s *State) ProjectList() ([]*vagrant_plugin_sdk.Ref_Project, error) {
var projects []Project
result := s.search().Find(&projects)
if result.Error != nil {
return nil, lookupErrorToStatus("projects", result.Error)
}
if err := s.projectIndexSet(memTxn, k, &value); err != nil {
return err
prefs := make([]*vagrant_plugin_sdk.Ref_Project, len(projects))
for i, prj := range projects {
prefs[i] = prj.ToProtoRef()
}
return prefs, nil
}
// Find a Project using a protobuf message
func (s *State) ProjectFind(p *vagrant_server.Project) (*vagrant_server.Project, error) {
project, err := s.ProjectFromProtoFuzzy(p)
if err != nil {
return nil, saveErrorToStatus("project", err)
}
return project.ToProto(), nil
}
// Delete a project
func (s *State) ProjectDelete(
p *vagrant_plugin_sdk.Ref_Project,
) error {
project, err := s.ProjectFromProtoRef(p)
// If the record was not found, we return with no error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
// If an unexpected error was encountered, return it
if err != nil {
return deleteErrorToStatus("project", err)
}
result := s.db.Delete(project)
if result.Error != nil {
return deleteErrorToStatus("project", err)
}
return nil
})
}
func projectIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: projectIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
projectIndexIdIndexName: {
Name: projectIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
projectIndexNameIndexName: {
Name: projectIndexNameIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
projectIndexPathIndexName: {
Name: projectIndexPathIndexName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Path",
Lowercase: false,
},
},
},
}
}
const (
projectIndexIdIndexName = "id"
projectIndexNameIndexName = "name"
projectIndexPathIndexName = "path"
projectIndexTableName = "project-index"
var (
_ scope = (*Project)(nil)
)
type projectIndexRecord struct {
Id string
Name string
Path string
}
func (s *State) newProjectIndexRecord(p *vagrant_server.Project) *projectIndexRecord {
return &projectIndexRecord{
Id: p.ResourceId,
Name: strings.ToLower(p.Name),
Path: p.Path,
}
}
func (s *State) newProjectIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Project) *projectIndexRecord {
return &projectIndexRecord{
Id: ref.ResourceId,
Name: strings.ToLower(ref.Name),
}
}
func (s *State) projectId(p *vagrant_server.Project) []byte {
return []byte(p.ResourceId)
}
func (s *State) projectIdByRef(ref *vagrant_plugin_sdk.Ref_Project) []byte {
if ref == nil {
return []byte{}
}
return []byte(ref.ResourceId)
}

View File

@ -34,10 +34,8 @@ func TestProject(t *testing.T) {
defer s.Close()
basisRef := testBasis(t, s)
resourceId := "AbCdE"
// Set
err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
ResourceId: resourceId,
result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
Basis: basisRef,
Path: "idontexist",
}))
@ -46,11 +44,11 @@ func TestProject(t *testing.T) {
// Get exact
{
resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId)
require.Equal(resp.ResourceId, result.ResourceId)
}
@ -70,8 +68,7 @@ func TestProject(t *testing.T) {
basisRef := testBasis(t, s)
// Set
err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
ResourceId: "AbCdE",
result, err := s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
Basis: basisRef,
Path: "idontexist",
}))
@ -79,7 +76,7 @@ func TestProject(t *testing.T) {
// Read
resp, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE",
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
@ -87,7 +84,7 @@ func TestProject(t *testing.T) {
// Delete
{
err := s.ProjectDelete(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE",
ResourceId: result.ResourceId,
Basis: basisRef,
})
require.NoError(err)
@ -96,7 +93,7 @@ func TestProject(t *testing.T) {
// Read
{
_, err := s.ProjectGet(&vagrant_plugin_sdk.Ref_Project{
ResourceId: "AbCdE",
ResourceId: result.ResourceId,
})
require.Error(err)
require.Equal(codes.NotFound, status.Code(err))

View File

@ -0,0 +1,172 @@
package state
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/hashicorp/vagrant-plugin-sdk/internal-shared/dynamic"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// ProtoValue stores a protobuf message in the database
// as a JSON field so it can be queried. Note that using
// this type can result in lossy storage depending on
// types in the message
type ProtoValue struct {
Message proto.Message
}
// User consumable data type name
func (p *ProtoValue) GormDataType() string {
return datatypes.JSON{}.GormDataType()
}
// Driver consumable data type name
func (p *ProtoValue) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return datatypes.JSON{}.GormDBDataType(db, field)
}
// Unmarshals the store value back to original type
func (p *ProtoValue) Scan(value interface{}) error {
if value == nil {
return nil
}
s, ok := value.(string)
if !ok {
return fmt.Errorf("failed to unmarshal protobuf value, invalid type (%T)", value)
}
if s == "" {
return nil
}
v := []byte(s)
var m anypb.Any
err := protojson.Unmarshal(v, &m)
if err != nil {
return err
}
_, i, err := dynamic.DecodeAny(&m)
if err != nil {
return err
}
pm, ok := i.(proto.Message)
if !ok {
return fmt.Errorf("failed to set unmarshaled proto value, invalid type (%T)", i)
}
p.Message = pm
return nil
}
// Marshal the value for storage in the database
func (p *ProtoValue) Value() (driver.Value, error) {
if p == nil || p.Message == nil {
return nil, nil
}
a, err := dynamic.EncodeAny(p.Message)
if err != nil {
return nil, err
}
j, err := protojson.Marshal(a)
if err != nil {
return nil, err
}
return string(j), nil
}
// ProtoRaw stores a protobuf message in the database
// as raw bytes. Note that when using this type the
// contents of the protobuf message cannot be queried
type ProtoRaw struct {
Message proto.Message
}
// User consumable data type name
func (p *ProtoRaw) GormDataType() string {
return "bytes"
}
// Driver consumable data type name
func (p *ProtoRaw) GormDBDataType(db *gorm.DB, field *schema.Field) string {
return "BLOB"
}
// Unmarshals the store value back to original type
func (p *ProtoRaw) Scan(value interface{}) error {
if p == nil || value == nil {
return nil
}
s, ok := value.(string)
if !ok {
return fmt.Errorf("failed to unmarshal protobuf raw, invalid type (%T)", value)
}
if s == "" {
return nil
}
v := []byte(s)
var a anypb.Any
err := proto.Unmarshal(v, &a)
if err != nil {
return err
}
_, m, err := dynamic.DecodeAny(&a)
if err != nil {
return err
}
pm, ok := m.(proto.Message)
if !ok {
return fmt.Errorf("failed to set unmarshaled proto raw, invalid type (%T)", m)
}
p.Message = pm
return nil
}
// Marshal the value for storage in the database
func (p *ProtoRaw) Value() (driver.Value, error) {
if p == nil || p.Message == nil {
return nil, nil
}
m, err := dynamic.EncodeAny(p.Message)
if err != nil {
return nil, err
}
r, err := proto.Marshal(m)
if err != nil {
return nil, err
}
return string(r), nil
}
var (
_ sql.Scanner = (*ProtoValue)(nil)
_ driver.Valuer = (*ProtoValue)(nil)
_ schema.GormDataTypeInterface = (*ProtoValue)(nil)
_ migrator.GormDataTypeInterface = (*ProtoValue)(nil)
_ sql.Scanner = (*ProtoRaw)(nil)
_ driver.Valuer = (*ProtoRaw)(nil)
_ schema.GormDataTypeInterface = (*ProtoRaw)(nil)
_ migrator.GormDataTypeInterface = (*ProtoRaw)(nil)
)

View File

@ -1,102 +1,109 @@
package state
import (
"github.com/hashicorp/go-memdb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"errors"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
const (
runnerTableName = "runners"
runnerIdIndexName = "id"
)
type Runner struct {
gorm.Model
func init() {
schemas = append(schemas, runnerSchema)
Rid *string `gorm:"uniqueIndex;not null" mapstructure:"Id"`
ByIdOnly bool
Components []*Component `gorm:"many2many:runner_components"`
}
func runnerSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: runnerTableName,
Indexes: map[string]*memdb.IndexSchema{
runnerIdIndexName: {
Name: runnerIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: true,
},
},
},
func init() {
models = append(models, &Runner{})
}
func (r *Runner) ToProto() *vagrant_server.Runner {
if r == nil {
return nil
}
components := make([]*vagrant_server.Component, len(r.Components))
for i, c := range r.Components {
components[i] = c.ToProto()
}
return &vagrant_server.Runner{
Id: *r.Rid,
ByIdOnly: r.ByIdOnly,
Components: components,
}
}
type runnerRecord struct {
// The full Runner. All other fiels are derivatives of this.
Runner *vagrant_server.Runner
func (s *State) RunnerFromProto(p *vagrant_server.Runner) (*Runner, error) {
if p.Id == "" {
return nil, gorm.ErrRecordNotFound
}
// Id of the runner
Id string
var runner Runner
result := s.search().First(&runner, &Runner{Rid: &p.Id})
if result.Error != nil {
return nil, result.Error
}
return &runner, nil
}
func (s *State) RunnerById(id string) (*vagrant_server.Runner, error) {
r, err := s.RunnerFromProto(&vagrant_server.Runner{Id: id})
if err != nil {
return nil, lookupErrorToStatus("runner", err)
}
return r.ToProto(), nil
}
func (s *State) RunnerCreate(r *vagrant_server.Runner) error {
txn := s.inmem.Txn(true)
defer txn.Abort()
// Create our runner
if err := txn.Insert(runnerTableName, newRunnerRecord(r)); err != nil {
return status.Errorf(codes.Aborted, err.Error())
runner, err := s.RunnerFromProto(r)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return lookupErrorToStatus("runner", err)
}
txn.Commit()
if err != nil {
runner = &Runner{}
}
err = s.softDecode(r, runner)
if err != nil {
return saveErrorToStatus("runner", err)
}
result := s.db.Save(runner)
if result.Error != nil {
return saveErrorToStatus("runner", result.Error)
}
return nil
}
func (s *State) RunnerDelete(id string) error {
txn := s.inmem.Txn(true)
defer txn.Abort()
if _, err := txn.DeleteAll(runnerTableName, runnerIdIndexName, id); err != nil {
return status.Errorf(codes.Aborted, err.Error())
runner, err := s.RunnerFromProto(&vagrant_server.Runner{Id: id})
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
return nil
}
result := s.db.Delete(runner)
if result.Error != nil {
return deleteErrorToStatus("runner", result.Error)
}
txn.Commit()
return nil
}
func (s *State) RunnerById(id string) (*vagrant_server.Runner, error) {
txn := s.inmem.Txn(false)
raw, err := txn.First(runnerTableName, runnerIdIndexName, id)
txn.Abort()
if err != nil {
return nil, err
// Returns if there are no registered runners
func (s *State) runnerEmpty() (bool, error) {
var c int64
result := s.db.Model(&Runner{}).Count(&c)
if result.Error != nil {
return false, result.Error
}
if raw == nil {
return nil, status.Errorf(codes.NotFound, "runner ID not found")
}
return raw.(*runnerRecord).Runner, nil
}
// runnerEmpty returns true if there are no runners registered.
func (s *State) runnerEmpty(memTxn *memdb.Txn) (bool, error) {
iter, err := memTxn.LowerBound(runnerTableName, runnerIdIndexName, "")
if err != nil {
return false, err
}
return iter.Next() == nil, nil
}
// newRunnerRecord creates a runnerRecord from a runner.
func newRunnerRecord(r *vagrant_server.Runner) *runnerRecord {
rec := &runnerRecord{
Runner: r,
Id: r.Id,
}
return rec
return c < 1, nil
}

View File

@ -9,7 +9,8 @@ import (
"google.golang.org/grpc/status"
)
func TestRunner_crud(t *testing.T) {
func TestRunner(t *testing.T) {
t.Run("Basic CRUD", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
@ -22,7 +23,8 @@ func TestRunner_crud(t *testing.T) {
// We should be able to find it
found, err := s.RunnerById(rec.Id)
require.NoError(err)
require.Equal(rec, found)
require.Equal(rec.Id, found.Id)
require.Empty(found.Components)
// Delete that instance
require.NoError(s.RunnerDelete(rec.Id))
@ -35,6 +37,37 @@ func TestRunner_crud(t *testing.T) {
// Delete again should be fine
require.NoError(s.RunnerDelete(rec.Id))
})
t.Run("CRUD with components", func(t *testing.T) {
require := require.New(t)
s := TestState(t)
defer s.Close()
components := []*vagrant_server.Component{
{
Name: "command",
Type: vagrant_server.Component_COMMAND,
},
{
Name: "communicator",
Type: vagrant_server.Component_COMMUNICATOR,
},
}
// Create an instance
rec := &vagrant_server.Runner{
Id: "A",
Components: components,
}
require.NoError(s.RunnerCreate(rec))
found, err := s.RunnerById(rec.Id)
require.NoError(err)
require.Equal(rec.Id, found.Id)
require.Equal(rec.Components, found.Components)
})
}
func TestRunnerById_notFound(t *testing.T) {

View File

@ -4,137 +4,117 @@ package state
import (
"crypto/rand"
"errors"
"fmt"
"reflect"
"sync"
"time"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-memdb"
"github.com/oklog/ulid/v2"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// The global variables below can be set by init() functions of other
// files in this package to setup the database state for the server.
var (
// schemas is used to register schemas with the state store. Other files should
// use the init() callback to append to this.
// schemas is used to register schemas within the state store. Other
// files should use the init() callback to append to this.
schemas []schemaFn
// dbBuckets is the list of buckets that should be created by dbInit.
// Various components should use init() funcs to append to this.
dbBuckets [][]byte
// All the data persisted models defined. Other files should
// use the init() callback to append to this list.
models = []interface{}{}
// dbIndexers is the list of functions to call to initialize the
// in-memory indexes from the persisted db.
dbIndexers []indexFn
entropy = rand.Reader
// Error returned when proto value passed is nil
ErrEmptyProtoArgument = errors.New("no proto value provided")
// Error returned when a proto reference does not include its parent
ErrMissingProtoParent = errors.New("proto reference does not include parent")
)
// indexFn is the function type for initializing in-memory indexes from
// persisted data. This is usually specified as a method handle to a
// *State method.
//
// The bolt.Tx is read-only while the memdb.Txn is a write transaction.
type indexFn func(*State, *memdb.Txn) error
// State is the primary API for state mutation for the server.
type State struct {
// Connection to our database
db *gorm.DB
// inmem is our in-memory database that stores ephemeral data in an
// easier-to-query way. Some of this data may be periodically persisted
// but most of this data is meant to be lost when the process restarts.
inmem *memdb.MemDB
// db is our persisted on-disk database. This stores the bulk of data
// and supports a transactional model for safe concurrent access.
// inmem is used alongside db to store in-memory indexing information
// for more efficient lookups into db. This index is built online at
// boot.
db *bolt.DB
// hmacKeyNotEmpty is flipped to 1 when an hmac entry is set. This is
// used to determine if we're in a bootstrap state and can create a
// bootstrap token.
hmacKeyNotEmpty uint32
// indexers is used to track whether an indexer was called. This is
// initialized during New and set to nil at the end of New.
indexers map[uintptr]struct{}
// Where to log to
log hclog.Logger
// indexedJobs indicates how many job records we are tracking in memory
indexedJobs int
// Used to track prune records
pruneMu sync.Mutex
// Where to log to
log hclog.Logger
}
// New initializes a new State store.
func New(log hclog.Logger, db *bolt.DB) (*State, error) {
// Restore DB if necessary
db, err := finalizeRestore(log, db)
func New(log hclog.Logger, db *gorm.DB) (*State, error) {
log = log.Named("state")
err := db.AutoMigrate(models...)
if err != nil {
log.Trace("failure encountered during finalize restore", "error", err)
log.Trace("failure encountered during auto migration",
"error", err,
)
return nil, err
}
// Create the in-memory DB.
// Create the in-memory DB
inmem, err := memdb.NewMemDB(stateStoreSchema())
if err != nil {
log.Trace("failed to setup in-memory database", "error", err)
return nil, fmt.Errorf("Failed setting up state store: %s", err)
}
// Initialize and validate our on-disk format.
if err := dbInit(db); err != nil {
log.Error("failed to initialize and validate on-disk format", "error", err)
return nil, err
}
s := &State{inmem: inmem, db: db, log: log}
s := &State{
db: db,
inmem: inmem,
log: log,
}
// Initialize our set that'll track what memdb indexers we call.
// When we're done we always clear this out since it is never used
// again.
s.indexers = make(map[uintptr]struct{})
defer func() { s.indexers = nil }()
// Initialize our in-memory indexes
memTxn := s.inmem.Txn(true)
// Initialize the in-memory indicies
memTxn := inmem.Txn(true)
defer memTxn.Abort()
err = s.db.View(func(dbTxn *bolt.Tx) error {
for _, indexer := range dbIndexers {
// TODO: this should use callIndexer but it's broken as it prevents the multiple op indexers
// from properly running.
if err := indexer(s, dbTxn, memTxn); err != nil {
return err
}
if err := indexer(s, memTxn); err != nil {
return nil, err
}
return nil
})
if err != nil {
log.Error("failed to generate in memory index", "error", err)
return nil, err
}
memTxn.Commit()
return s, nil
}
// callIndexer calls the specified indexer exactly once. If it has been called
// before this returns no error. This must not be called concurrently. This
// can be used from indexers to ensure other data is indexed first.
func (s *State) callIndexer(fn indexFn, dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
fnptr := reflect.ValueOf(fn).Pointer()
if _, ok := s.indexers[fnptr]; ok {
return nil
}
s.indexers[fnptr] = struct{}{}
return fn(s, dbTxn, memTxn)
}
// Close should be called to gracefully close any resources.
func (s *State) Close() error {
return s.db.Close()
db, err := s.db.DB()
if err != nil {
return err
}
return db.Close()
}
// Prune should be called in a on a regular interval to allow State
@ -180,17 +160,66 @@ func stateStoreSchema() *memdb.DBSchema {
return db
}
// indexFn is the function type for initializing in-memory indexes from
// persisted data. This is usually specified as a method handle to a
// *State method.
//
// The bolt.Tx is read-only while the memdb.Txn is a write transaction.
type indexFn func(*State, *bolt.Tx, *memdb.Txn) error
func (*State) newResourceId() (string, error) {
id, err := ulid.New(ulid.Timestamp(time.Now()), entropy)
if err != nil {
return "", err
}
return id.String(), nil
// Provides db for searching
// NOTE: In most cases this should be used instead of accessing `db`
// directly when searching for values to ensure all associations are
// fully loaded in the results.
func (s *State) search() *gorm.DB {
return s.db.Preload(clause.Associations)
}
// Convert error to a GRPC status error when dealing with lookups
func lookupErrorToStatus(
typeName string, // thing trying to be found (basis, project, etc)
err error, // error to convert
) error {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errorToStatus(fmt.Errorf("failed to locate %s (%w)", typeName, err))
}
if errors.Is(err, ErrEmptyProtoArgument) || errors.Is(err, ErrMissingProtoParent) {
return errorToStatus(fmt.Errorf("cannot lookup %s (%w)", typeName, err))
}
return errorToStatus(fmt.Errorf("unexpected error encountered during %s lookup (%w)", typeName, err))
}
// Convert error to GRPC status error when failing to save
func saveErrorToStatus(
typeName string, // thing trying to be saved
err error, // error to convert
) error {
var vErr validation.Error
if errors.Is(err, ErrEmptyProtoArgument) ||
errors.Is(err, ErrMissingProtoParent) ||
errors.As(err, &vErr) {
return errorToStatus(fmt.Errorf("cannot save %s (%w)", typeName, err))
}
return errorToStatus(fmt.Errorf("unexpected error encountered while saving %s (%w)", typeName, err))
}
// Convert error to GRPC status error when failing to delete
func deleteErrorToStatus(
typeName string, // thing trying to be deleted
err error, // error to convert
) error {
return errorToStatus(fmt.Errorf("unexpected error encountered while deleting %s (%w)", typeName, err))
}
// Convert error to a GRPC status error
func errorToStatus(
err error, // error to convert
) error {
if errors.Is(err, gorm.ErrRecordNotFound) {
return status.Error(codes.NotFound, err.Error())
}
var vErr validation.Error
if errors.Is(err, ErrEmptyProtoArgument) ||
errors.Is(err, ErrMissingProtoParent) ||
errors.As(err, &vErr) {
return status.Error(codes.FailedPrecondition, err.Error())
}
return status.Error(codes.Internal, err.Error())
}

View File

@ -1,402 +1,354 @@
package state
import (
"github.com/google/uuid"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"errors"
"fmt"
"github.com/go-ozzo/ozzo-validation/v4"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
"gorm.io/gorm"
)
var targetBucket = []byte("target")
func init() {
dbBuckets = append(dbBuckets, targetBucket)
dbIndexers = append(dbIndexers, (*State).targetIndexInit)
schemas = append(schemas, targetIndexSchema)
models = append(models, &Target{})
}
func (s *State) TargetFind(m *vagrant_server.Target) (*vagrant_server.Target, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
type Target struct {
gorm.Model
var result *vagrant_server.Target
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.targetFind(dbTxn, memTxn, m)
return err
})
return result, err
Configuration *ProtoRaw
Jobs []*InternalJob `gorm:"polymorphic:Scope;" mapstructure:"-"`
Metadata MetadataSet
Name *string `gorm:"uniqueIndex:idx_pname;not null"`
Parent *Target `gorm:"foreignkey:ID"`
ParentID uint `mapstructure:"-"`
Project *Project
ProjectID *uint `gorm:"uniqueIndex:idx_pname;not null" mapstructure:"-"`
Provider *string
Record *ProtoRaw
ResourceId *string `gorm:"<-:create;uniqueIndex;not null"`
State vagrant_server.Operation_PhysicalState
Subtargets []*Target `gorm:"foreignkey:ParentID"`
Uuid *string `gorm:"uniqueIndex"`
l hclog.Logger
}
func (s *State) TargetPut(target *vagrant_server.Target) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.targetPut(dbTxn, memTxn, target)
})
if err == nil {
memTxn.Commit()
}
return err
func (t *Target) scope() interface{} {
return t
}
func (s *State) TargetDelete(ref *vagrant_plugin_sdk.Ref_Target) error {
memTxn := s.inmem.Txn(true)
defer memTxn.Abort()
err := s.db.Update(func(dbTxn *bolt.Tx) error {
return s.targetDelete(dbTxn, memTxn, ref)
})
if err == nil {
memTxn.Commit()
}
return err
}
func (s *State) TargetGet(ref *vagrant_plugin_sdk.Ref_Target) (*vagrant_server.Target, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
var result *vagrant_server.Target
err := s.db.View(func(dbTxn *bolt.Tx) error {
var err error
result, err = s.targetGet(dbTxn, memTxn, ref)
return err
})
return result, err
}
func (s *State) TargetList() ([]*vagrant_plugin_sdk.Ref_Target, error) {
memTxn := s.inmem.Txn(false)
defer memTxn.Abort()
return s.targetList(memTxn)
}
func (s *State) targetFind(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
m *vagrant_server.Target,
) (*vagrant_server.Target, error) {
var match *targetIndexRecord
req := s.newTargetIndexRecord(m)
// Start with the resource id first
if req.Id != "" {
if raw, err := memTxn.First(
targetIndexTableName,
targetIndexIdIndexName,
req.Id,
); raw != nil && err == nil {
match = raw.(*targetIndexRecord)
}
}
// Try the name + project next
if match == nil && req.Name != "" {
// Match the name first
raw, err := memTxn.Get(
targetIndexTableName,
targetIndexNameIndexName,
req.Name,
)
if err != nil {
return nil, err
}
// Check for matching project next
if req.ProjectId != "" {
for e := raw.Next(); e != nil; e = raw.Next() {
targetIndexEntry := e.(*targetIndexRecord)
if targetIndexEntry.ProjectId == req.ProjectId {
match = targetIndexEntry
break
}
}
} else {
e := raw.Next()
if e != nil {
match = e.(*targetIndexRecord)
}
}
}
// Finally try the uuid
if match == nil && req.Uuid != "" {
if raw, err := memTxn.First(
targetIndexTableName,
targetIndexUuidName,
req.Uuid,
); raw != nil && err == nil {
match = raw.(*targetIndexRecord)
}
}
if match == nil {
return nil, status.Errorf(codes.NotFound, "record not found for Target (name: %s resource_id: %s)", m.Name, m.ResourceId)
}
return s.targetGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Target{
ResourceId: match.Id,
})
}
func (s *State) targetList(
memTxn *memdb.Txn,
) ([]*vagrant_plugin_sdk.Ref_Target, error) {
iter, err := memTxn.Get(targetIndexTableName, targetIndexIdIndexName+"_prefix", "")
if err != nil {
return nil, err
}
var result []*vagrant_plugin_sdk.Ref_Target
for {
next := iter.Next()
if next == nil {
break
}
result = append(result, &vagrant_plugin_sdk.Ref_Target{
ResourceId: next.(*targetIndexRecord).Id,
Name: next.(*targetIndexRecord).Name,
Project: &vagrant_plugin_sdk.Ref_Project{
ResourceId: next.(*targetIndexRecord).ProjectId,
},
})
}
return result, nil
}
func (s *State) targetPut(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
value *vagrant_server.Target,
) (err error) {
s.log.Trace("storing target", "target", value, "project",
value.GetProject(), "basis", value.GetProject().GetBasis())
p, err := s.projectGet(dbTxn, memTxn, value.Project)
if err != nil {
s.log.Error("failed to locate project for target", "target", value,
"project", p, "error", err)
return
}
if value.ResourceId == "" {
// If no resource id is provided, try to find the target based on the name and project
foundTarget, erro := s.targetFind(dbTxn, memTxn, value)
// If an invalid return code is returned from find then an error occured
if _, ok := status.FromError(erro); !ok {
return erro
}
if foundTarget != nil {
// Make sure the config doesn't get merged - we want the config to overwrite the old config
finalConfig := proto.Clone(value.Configuration)
// Merge found target with provided target
proto.Merge(value, foundTarget)
value.ResourceId = foundTarget.ResourceId
value.Uuid = foundTarget.Uuid
value.Configuration = finalConfig.(*vagrant_plugin_sdk.Args_ConfigData)
} else {
s.log.Trace("target has no resource id and could not find matching target, assuming new target",
"target", value)
if value.ResourceId, err = s.newResourceId(); err != nil {
s.log.Error("failed to create resource id for target", "target", value,
"error", err)
return
}
}
if value.Uuid == "" {
s.log.Trace("target has no uuid assigned, assigning...", "target", value)
uID, err := uuid.NewUUID()
if err != nil {
// Set a public ID on the target before creating
func (t *Target) BeforeSave(tx *gorm.DB) error {
if t.ResourceId == nil {
if err := t.setId(); err != nil {
return err
}
value.Uuid = uID.String()
}
}
s.log.Trace("storing target to db", "target", value)
id := s.targetId(value)
b := dbTxn.Bucket(targetBucket)
if err = dbPut(b, id, value); err != nil {
s.log.Error("failed to store target in db", "target", value, "error", err)
return
}
s.log.Trace("indexing target", "target", value)
if err = s.targetIndexSet(memTxn, id, value); err != nil {
s.log.Error("failed to index target", "target", value, "error", err)
return
}
s.log.Trace("adding target to project", "target", value, "project", p)
pp := serverptypes.Project{Project: p}
if pp.AddTarget(value) {
s.log.Trace("target added to project, updating project", "project", p)
if err = s.projectPut(dbTxn, memTxn, p); err != nil {
s.log.Error("failed to update project", "project", p, "error", err)
return
}
} else {
s.log.Trace("target already exists in project", "target", value, "project", p)
}
return
}
func (s *State) targetGet(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Target,
) (*vagrant_server.Target, error) {
var result vagrant_server.Target
b := dbTxn.Bucket(targetBucket)
return &result, dbGet(b, s.targetIdByRef(ref), &result)
}
func (s *State) targetDelete(
dbTxn *bolt.Tx,
memTxn *memdb.Txn,
ref *vagrant_plugin_sdk.Ref_Target,
) (err error) {
p, err := s.projectGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Project{ResourceId: ref.Project.ResourceId})
if err != nil {
return
}
if err = dbTxn.Bucket(targetBucket).Delete(s.targetIdByRef(ref)); err != nil {
return
}
if err = memTxn.Delete(targetIndexTableName, s.newTargetIndexRecordByRef(ref)); err != nil {
return
}
pp := serverptypes.Project{Project: p}
if pp.DeleteTargetRef(ref) {
if err = s.projectPut(dbTxn, memTxn, pp.Project); err != nil {
return
}
}
return
}
func (s *State) targetIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Target) error {
return txn.Insert(targetIndexTableName, s.newTargetIndexRecord(value))
}
func (s *State) targetIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
bucket := dbTxn.Bucket(targetBucket)
return bucket.ForEach(func(k, v []byte) error {
var value vagrant_server.Target
if err := proto.Unmarshal(v, &value); err != nil {
return err
}
if err := s.targetIndexSet(memTxn, k, &value); err != nil {
if err := t.validate(tx); err != nil {
return err
}
return nil
})
}
func targetIndexSchema() *memdb.TableSchema {
return &memdb.TableSchema{
Name: targetIndexTableName,
Indexes: map[string]*memdb.IndexSchema{
targetIndexIdIndexName: {
Name: targetIndexIdIndexName,
AllowMissing: false,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Id",
Lowercase: false,
},
},
targetIndexNameIndexName: {
Name: targetIndexNameIndexName,
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "Name",
Lowercase: true,
},
},
targetIndexProjectIndexName: {
Name: targetIndexProjectIndexName,
AllowMissing: false,
Unique: false,
Indexer: &memdb.StringFieldIndex{
Field: "ProjectId",
Lowercase: true,
},
},
targetIndexUuidName: {
Name: targetIndexUuidName,
AllowMissing: true,
Unique: true,
Indexer: &memdb.StringFieldIndex{
Field: "Uuid",
Lowercase: true,
},
},
},
func (t *Target) validate(tx *gorm.DB) error {
err := validation.ValidateStruct(t,
validation.Field(&t.Name,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{Name: t.Name, ProjectID: t.ProjectID}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
validation.Field(&t.ResourceId,
validation.Required,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{ResourceId: t.ResourceId}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
validation.Field(&t.Uuid,
validation.When(t.Uuid != nil,
validation.By(
checkUnique(
tx.Model(&Target{}).
Where(&Target{Uuid: t.Uuid}).
Not(&Target{Model: gorm.Model{ID: t.ID}}),
),
),
),
),
// validation.Field(&t.ProjectID, validation.Required), TODO(spox): why are these empty?
)
if err != nil {
return err
}
return nil
}
const (
targetIndexIdIndexName = "id"
targetIndexNameIndexName = "name"
targetIndexProjectIndexName = "project"
targetIndexUuidName = "uuid"
targetIndexTableName = "target-index"
func (t *Target) setId() error {
id, err := server.Id()
if err != nil {
return err
}
t.ResourceId = &id
return nil
}
// Convert target to reference protobuf message
func (t *Target) ToProtoRef() *vagrant_plugin_sdk.Ref_Target {
if t == nil {
return nil
}
var ref vagrant_plugin_sdk.Ref_Target
err := decode(t, &ref)
if err != nil {
panic("failed to decode target to ref: " + err.Error())
}
return &ref
}
// Convert target to protobuf message
func (t *Target) ToProto() *vagrant_server.Target {
if t == nil {
return nil
}
var target vagrant_server.Target
err := decode(t, &target)
if err != nil {
panic("failed to decode target: " + err.Error())
}
return &target
}
// Load a Target from reference protobuf message
func (s *State) TargetFromProtoRef(
ref *vagrant_plugin_sdk.Ref_Target,
) (*Target, error) {
if ref == nil {
return nil, ErrEmptyProtoArgument
}
if ref.ResourceId == "" {
return nil, gorm.ErrRecordNotFound
}
var target Target
result := s.search().Preload("Project.Basis").First(&target,
&Target{ResourceId: &ref.ResourceId},
)
if result.Error != nil {
return nil, result.Error
}
return &target, nil
}
func (s *State) TargetFromProtoRefFuzzy(
ref *vagrant_plugin_sdk.Ref_Target,
) (*Target, error) {
target, err := s.TargetFromProtoRef(ref)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if ref.Project == nil {
return nil, ErrMissingProtoParent
}
if ref.Name == "" {
return nil, gorm.ErrRecordNotFound
}
target = &Target{}
result := s.search().
Joins("Project", &Project{ResourceId: &ref.Project.ResourceId}).
Preload("Project.Basis").
First(target, &Target{Name: &ref.Name})
if result.Error != nil {
return nil, result.Error
}
return target, nil
}
// Load a Target from protobuf message
func (s *State) TargetFromProto(
t *vagrant_server.Target,
) (*Target, error) {
target, err := s.TargetFromProtoRef(
&vagrant_plugin_sdk.Ref_Target{
ResourceId: t.ResourceId,
},
)
if err != nil {
return nil, err
}
return target, nil
}
func (s *State) TargetFromProtoFuzzy(
t *vagrant_server.Target,
) (*Target, error) {
target, err := s.TargetFromProto(t)
if err == nil {
return target, nil
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if t.Project == nil {
return nil, ErrMissingProtoParent
}
target = &Target{}
query := &Target{Name: &t.Name}
tx := s.db.
Preload("Project",
s.db.Where(
&Project{ResourceId: &t.Project.ResourceId},
),
)
if t.Name != "" {
query.Name = &t.Name
}
if t.Uuid != "" {
query.Uuid = &t.Uuid
tx = tx.Or("uuid LIKE ?", fmt.Sprintf("%%%s%%", t.Uuid))
}
result := s.search().Joins("Project").
Preload("Project.Basis").
Where("Project.resource_id = ?", t.Project.ResourceId).
First(target, query)
if result.Error != nil {
return nil, result.Error
}
return target, nil
}
// Get a target record using a reference protobuf message
func (s *State) TargetGet(
ref *vagrant_plugin_sdk.Ref_Target,
) (*vagrant_server.Target, error) {
t, err := s.TargetFromProtoRef(ref)
if err != nil {
return nil, lookupErrorToStatus("target", err)
}
return t.ToProto(), nil
}
// List all target records
func (s *State) TargetList() ([]*vagrant_plugin_sdk.Ref_Target, error) {
var targets []Target
result := s.search().Find(&targets)
if result.Error != nil {
return nil, lookupErrorToStatus("targets", result.Error)
}
trefs := make([]*vagrant_plugin_sdk.Ref_Target, len(targets))
for i, t := range targets {
trefs[i] = t.ToProtoRef()
}
return trefs, nil
}
// Delete a target by reference protobuf message
func (s *State) TargetDelete(
t *vagrant_plugin_sdk.Ref_Target,
) error {
target, err := s.TargetFromProtoRef(t)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
if err != nil {
return lookupErrorToStatus("target", err)
}
result := s.db.Delete(target)
if result.Error != nil {
return deleteErrorToStatus("target", result.Error)
}
return nil
}
// Store a Target
func (s *State) TargetPut(
t *vagrant_server.Target,
) (*vagrant_server.Target, error) {
target, err := s.TargetFromProto(t)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, lookupErrorToStatus("target", err)
}
// Make sure we don't have a nil
if err != nil {
target = &Target{}
}
s.log.Info("pre-decode our target project", "project", target.Project)
err = s.softDecode(t, target)
if err != nil {
return nil, saveErrorToStatus("target", err)
}
s.log.Info("post-decode our target project", "project", target.Project)
if target.Project == nil {
panic("stop")
}
result := s.db.Save(target)
s.log.Info("after save target project status", "project", target.Project, "error", err)
if result.Error != nil {
return nil, saveErrorToStatus("target", result.Error)
}
return target.ToProto(), nil
}
// Find a Target
func (s *State) TargetFind(
t *vagrant_server.Target,
) (*vagrant_server.Target, error) {
target, err := s.TargetFromProtoFuzzy(t)
if err != nil {
return nil, lookupErrorToStatus("target", err)
}
return target.ToProto(), nil
}
var (
_ scope = (*Target)(nil)
)
type targetIndexRecord struct {
Id string // Resource ID
Name string // Target Name
ProjectId string // Project Resource ID
Uuid string // Target UUID
}
func (s *State) newTargetIndexRecord(m *vagrant_server.Target) *targetIndexRecord {
var projectResourceId string
if m.Project != nil {
projectResourceId = m.Project.ResourceId
}
i := &targetIndexRecord{
Id: m.ResourceId,
Name: m.Name,
ProjectId: projectResourceId,
Uuid: m.Uuid,
}
return i
}
func (s *State) newTargetIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Target) *targetIndexRecord {
var projectResourceId string
if ref.Project != nil {
projectResourceId = ref.Project.ResourceId
}
return &targetIndexRecord{
Id: ref.ResourceId,
Name: ref.Name,
ProjectId: projectResourceId,
}
}
func (s *State) targetId(m *vagrant_server.Target) []byte {
return []byte(m.ResourceId)
}
func (s *State) targetIdByRef(m *vagrant_plugin_sdk.Ref_Target) []byte {
return []byte(m.ResourceId)
}

View File

@ -11,7 +11,6 @@ import (
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
)
func TestTarget(t *testing.T) {
@ -36,13 +35,11 @@ func TestTarget(t *testing.T) {
defer s.Close()
projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
ResourceId: resourceId,
result, err := s.TargetPut(&vagrant_server.Target{
Project: projectRef,
Name: "test",
}))
})
require.NoError(err)
// Ensure there is one entry
@ -51,12 +48,14 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1)
// Try to insert duplicate entry
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
ResourceId: resourceId,
doubleResult, err := s.TargetPut(&vagrant_server.Target{
ResourceId: result.ResourceId,
Project: projectRef,
Name: "test",
}))
})
require.NoError(err)
require.Equal(doubleResult.ResourceId, result.ResourceId)
require.Equal(doubleResult.Project, result.Project)
// Ensure there is still one entry
resp, err = s.TargetList()
@ -64,11 +63,11 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1)
// Try to insert duplicate entry by just name and project
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
_, err = s.TargetPut(&vagrant_server.Target{
Project: projectRef,
Name: "test",
}))
require.NoError(err)
})
require.Error(err)
// Ensure there is still one entry
resp, err = s.TargetList()
@ -78,9 +77,8 @@ func TestTarget(t *testing.T) {
// Try to insert duplicate config
key, _ := anypb.New(&wrapperspb.StringValue{Value: "vm"})
value, _ := anypb.New(&wrapperspb.StringValue{Value: "value"})
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
Project: projectRef,
Name: "test",
_, err = s.TargetPut(&vagrant_server.Target{
ResourceId: result.ResourceId,
Configuration: &vagrant_plugin_sdk.Args_ConfigData{
Data: &vagrant_plugin_sdk.Args_Hash{
Entries: []*vagrant_plugin_sdk.Args_HashEntry{
@ -91,11 +89,10 @@ func TestTarget(t *testing.T) {
},
},
},
}))
})
require.NoError(err)
err = s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
Project: projectRef,
Name: "test",
_, err = s.TargetPut(&vagrant_server.Target{
ResourceId: result.ResourceId,
Configuration: &vagrant_plugin_sdk.Args_ConfigData{
Data: &vagrant_plugin_sdk.Args_Hash{
Entries: []*vagrant_plugin_sdk.Args_HashEntry{
@ -106,7 +103,7 @@ func TestTarget(t *testing.T) {
},
},
},
}))
})
require.NoError(err)
// Ensure there is still one entry
@ -115,9 +112,11 @@ func TestTarget(t *testing.T) {
require.Len(resp, 1)
// Ensure the config did not merge
targetResp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(targetResp.Configuration)
require.NotNil(targetResp.Configuration.Data)
require.Len(targetResp.Configuration.Data.Entries, 1)
vmAny := targetResp.Configuration.Data.Entries[0].Value
vmString := wrapperspb.StringValue{}
@ -127,11 +126,11 @@ func TestTarget(t *testing.T) {
// Get exact
{
resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId)
require.Equal(resp.ResourceId, result.ResourceId)
}
@ -150,18 +149,16 @@ func TestTarget(t *testing.T) {
defer s.Close()
projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
ResourceId: resourceId,
result, err := s.TargetPut(&vagrant_server.Target{
Project: projectRef,
Name: "test",
}))
})
require.NoError(err)
// Read
resp, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
@ -169,7 +166,7 @@ func TestTarget(t *testing.T) {
// Delete
{
err := s.TargetDelete(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
Project: projectRef,
})
require.NoError(err)
@ -178,7 +175,7 @@ func TestTarget(t *testing.T) {
// Read
{
_, err := s.TargetGet(&vagrant_plugin_sdk.Ref_Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.Error(err)
require.Equal(codes.NotFound, status.Code(err))
@ -199,33 +196,30 @@ func TestTarget(t *testing.T) {
defer s.Close()
projectRef := testProject(t, s)
resourceId := "AbCdE"
// Set
err := s.TargetPut(serverptypes.TestTarget(t, &vagrant_server.Target{
ResourceId: resourceId,
result, err := s.TargetPut(&vagrant_server.Target{
Project: projectRef,
Name: "test",
}))
})
require.NoError(err)
// Find by resource id
{
resp, err := s.TargetFind(&vagrant_server.Target{
ResourceId: resourceId,
ResourceId: result.ResourceId,
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId)
require.Equal(resp.ResourceId, result.ResourceId)
}
// Find by resource name
// Find by resource name without project
{
resp, err := s.TargetFind(&vagrant_server.Target{
Name: "test",
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId)
require.Error(err)
require.Nil(resp)
}
// Find by resource name+project
@ -235,7 +229,7 @@ func TestTarget(t *testing.T) {
})
require.NoError(err)
require.NotNil(resp)
require.Equal(resp.ResourceId, resourceId)
require.Equal(resp.ResourceId, result.ResourceId)
}
// Don't find nonexistent project
@ -243,8 +237,8 @@ func TestTarget(t *testing.T) {
resp, err := s.TargetFind(&vagrant_server.Target{
Name: "test", Project: &vagrant_plugin_sdk.Ref_Project{ResourceId: "idontexist"},
})
require.Error(err)
require.Nil(resp)
require.Error(err)
}
// Don't find just by project

View File

@ -1,119 +1,187 @@
package state
import (
"bytes"
"io/ioutil"
"os"
"path/filepath"
"github.com/glebarez/sqlite"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
"github.com/imdario/mergo"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
bolt "go.etcd.io/bbolt"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// TestState returns an initialized State for testing.
func TestState(t testing.T) *State {
result, err := New(hclog.L(), testDB(t))
require.NoError(t, err)
return result
}
// TestStateReinit reinitializes the state by pretending to restart
// the server with the database associated with this state. This can be
// used to test index init logic.
//
// This safely copies the entire DB so the old state can continue running
// with zero impact.
func TestStateReinit(t testing.T, s *State) *State {
// Copy the old database to a brand new path
td, err := ioutil.TempDir("", "test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(td) })
path := filepath.Join(td, "test.db")
// Start db copy
require.NoError(t, s.db.View(func(tx *bolt.Tx) error {
return tx.CopyFile(path, 0600)
}))
// Open the new DB
db, err := bolt.Open(path, 0600, nil)
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
// Init new state
result, err := New(hclog.L(), db)
require.NoError(t, err)
return result
}
// TestStateRestart closes the given state and restarts it against the
// same DB file. Unlike TestStateReinit, this does not copy the data and
// the old state is no longer usable.
func TestStateRestart(t testing.T, s *State) (*State, error) {
path := s.db.Path()
require.NoError(t, s.Close())
// Open the new DB
db, err := bolt.Open(path, 0600, nil)
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
// Init new state
return New(hclog.L(), db)
}
func testDB(t testing.T) *bolt.DB {
t.Helper()
// Temporary directory for the database
td, err := ioutil.TempDir("", "test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(td) })
var buf bytes.Buffer
l := hclog.New(&hclog.LoggerOptions{
Name: "test",
Level: hclog.Trace,
Output: &buf,
IncludeLocation: true,
})
// Create the DB
db, err := bolt.Open(filepath.Join(td, "test.db"), 0600, nil)
t.Cleanup(func() {
t.Log(buf.String())
})
result, err := New(l, testDB(t))
require.NoError(t, err)
t.Cleanup(func() { db.Close() })
return result
}
// // TestStateReinit reinitializes the state by pretending to restart
// // the server with the database associated with this state. This can be
// // used to test index init logic.
// //
// // This safely copies the entire DB so the old state can continue running
// // with zero impact.
// func TestStateReinit(t testing.T, s *State) *State {
// // Copy the old database to a brand new path
// td, err := ioutil.TempDir("", "test")
// require.NoError(t, err)
// t.Cleanup(func() { os.RemoveAll(td) })
// path := filepath.Join(td, "test.db")
// // Start db copy
// require.NoError(t, s.db.View(func(tx *bolt.Tx) error {
// return tx.CopyFile(path, 0600)
// }))
// // Open the new DB
// db, err := bolt.Open(path, 0600, nil)
// require.NoError(t, err)
// t.Cleanup(func() { db.Close() })
// // Init new state
// result, err := New(hclog.L(), db)
// require.NoError(t, err)
// return result
// }
// // TestStateRestart closes the given state and restarts it against the
// // same DB file. Unlike TestStateReinit, this does not copy the data and
// // the old state is no longer usable.
// func TestStateRestart(t testing.T, s *State) (*State, error) {
// path := s.db.Path()
// require.NoError(t, s.Close())
// // Open the new DB
// db, err := bolt.Open(path, 0600, nil)
// require.NoError(t, err)
// t.Cleanup(func() { db.Close() })
// // Init new state
// return New(hclog.L(), db)
// }
func testDB(t testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(""), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
t.Cleanup(func() {
dbconn, err := db.DB()
if err == nil {
dbconn.Close()
}
})
return db
}
// TestBasis creates the basis in the DB.
func testBasis(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Basis {
t.Helper()
td := testTempDir(t)
s.BasisPut(serverptypes.TestBasis(t, &vagrant_server.Basis{
ResourceId: "test-basis",
Path: td,
Name: "test-basis",
}))
return &vagrant_plugin_sdk.Ref_Basis{
ResourceId: "test-basis",
Path: td,
Name: "test-basis",
name := filepath.Base(td)
b := &Basis{
Name: &name,
Path: &td,
}
result := s.db.Save(b)
require.NoError(t, result.Error)
return b.ToProtoRef()
}
func testProject(t testing.T, s *State) *vagrant_plugin_sdk.Ref_Project {
t.Helper()
basisRef := testBasis(t, s)
s.ProjectPut(serverptypes.TestProject(t, &vagrant_server.Project{
ResourceId: "test-project",
Basis: basisRef,
Path: "idontexist",
Name: "test-project",
}))
return &vagrant_plugin_sdk.Ref_Project{
ResourceId: "test-project",
Path: "idontexist",
Name: "test-project",
Basis: basisRef,
b, err := s.BasisFromProtoRef(basisRef)
require.NoError(t, err)
td := testTempDir(t)
name := filepath.Base(td)
p := &Project{
Name: &name,
Path: &td,
Basis: b,
}
result := s.db.Save(p)
require.NoError(t, result.Error)
return p.ToProtoRef()
}
func testRunner(t testing.T, s *State, src *vagrant_server.Runner) *vagrant_server.Runner {
t.Helper()
if src == nil {
src = &vagrant_server.Runner{}
}
id, err := server.Id()
require.NoError(t, err)
base := &vagrant_server.Runner{Id: id}
require.NoError(t, mergo.Merge(src, base))
var runner Runner
require.NoError(t, s.decode(src, &runner))
result := s.db.Save(&runner)
require.NoError(t, result.Error)
return runner.ToProto()
}
func testJob(t testing.T, src *vagrant_server.Job) *vagrant_server.Job {
t.Helper()
require.NoError(t, mergo.Merge(src,
&vagrant_server.Job{
TargetRunner: &vagrant_server.Ref_Runner{
Target: &vagrant_server.Ref_Runner_Any{
Any: &vagrant_server.Ref_RunnerAny{},
},
},
DataSource: &vagrant_server.Job_DataSource{
Source: &vagrant_server.Job_DataSource_Local{
Local: &vagrant_server.Job_Local{},
},
},
Operation: &vagrant_server.Job_Noop_{
Noop: &vagrant_server.Job_Noop{},
},
},
))
return src
}
func testTempDir(t testing.T) string {
t.Helper()
dir, err := ioutil.TempDir("", "vagrant-test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(dir) })

View File

@ -0,0 +1,68 @@
package state
import (
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
"gorm.io/gorm"
)
type VagrantfileFormat uint8
const (
JSON VagrantfileFormat = VagrantfileFormat(vagrant_server.Vagrantfile_JSON)
HCL = VagrantfileFormat(vagrant_server.Vagrantfile_HCL)
RUBY = VagrantfileFormat(vagrant_server.Vagrantfile_RUBY)
)
type Vagrantfile struct {
gorm.Model
Format VagrantfileFormat
Unfinalized *ProtoRaw
Finalized *ProtoRaw
Raw []byte
Path string
}
func init() {
models = append(models, &Vagrantfile{})
}
func (v *Vagrantfile) ToProto() *vagrant_server.Vagrantfile {
if v == nil {
return nil
}
return &vagrant_server.Vagrantfile{
Format: vagrant_server.Vagrantfile_Format(v.Format),
Raw: v.Raw,
Path: &vagrant_plugin_sdk.Args_Path{
Path: v.Path,
},
Unfinalized: v.Unfinalized.Message.(*vagrant_plugin_sdk.Args_Hash),
Finalized: v.Finalized.Message.(*vagrant_plugin_sdk.Args_Hash),
}
}
func (v *Vagrantfile) UpdateFromProto(vf *vagrant_server.Vagrantfile) *Vagrantfile {
v.Format = VagrantfileFormat(vf.Format)
v.Unfinalized = &ProtoRaw{Message: vf.Unfinalized}
v.Finalized = &ProtoRaw{Message: vf.Finalized}
v.Raw = vf.Raw
v.Path = vf.Path.Path
return v
}
func (s *State) VagrantfileFromProto(v *vagrant_server.Vagrantfile) *Vagrantfile {
file := &Vagrantfile{
Format: VagrantfileFormat(v.Format),
Unfinalized: &ProtoRaw{Message: v.Unfinalized},
Finalized: &ProtoRaw{Message: v.Finalized},
Raw: v.Raw,
}
if v.Path != nil {
file.Path = v.Path.Path
}
return file
}

View File

@ -0,0 +1,31 @@
package state
import (
"github.com/go-ozzo/ozzo-validation/v4"
"gorm.io/gorm"
)
type ValidationCode string
const (
VALIDATION_UNIQUE ValidationCode = "unique"
)
func checkUnique(tx *gorm.DB) validation.RuleFunc {
return func(value interface{}) error {
var count int64
result := tx.Count(&count)
if result.Error != nil {
return validation.NewInternalError(result.Error)
}
if count > 0 {
return validation.NewError(
string(VALIDATION_UNIQUE),
"must be unique",
)
}
return nil
}
}