2022-04-25 12:26:23 -05:00

397 lines
9.7 KiB
Go

package state
import (
"github.com/golang/protobuf/proto"
"github.com/google/uuid"
"github.com/hashicorp/go-memdb"
bolt "go.etcd.io/bbolt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/vagrant-plugin-sdk/proto/vagrant_plugin_sdk"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
serverptypes "github.com/hashicorp/vagrant/internal/server/ptypes"
)
var targetBucket = []byte("target")
func init() {
dbBuckets = append(dbBuckets, targetBucket)
dbIndexers = append(dbIndexers, (*State).targetIndexInit)
schemas = append(schemas, targetIndexSchema)
}
// TargetFind looks up a stored target using whichever identifying fields are
// set on m. The lookup runs inside a read-only database transaction and a
// read-only in-memory index transaction.
func (s *State) TargetFind(m *vagrant_server.Target) (*vagrant_server.Target, error) {
	txn := s.inmem.Txn(false)
	defer txn.Abort()

	var found *vagrant_server.Target
	viewErr := s.db.View(func(tx *bolt.Tx) (err error) {
		found, err = s.targetFind(tx, txn, m)
		return err
	})

	return found, viewErr
}
// TargetPut stores target in the database and updates the in-memory index.
// The index changes are committed only after the database update succeeds.
func (s *State) TargetPut(target *vagrant_server.Target) error {
	txn := s.inmem.Txn(true)
	defer txn.Abort()

	if err := s.db.Update(func(tx *bolt.Tx) error {
		return s.targetPut(tx, txn, target)
	}); err != nil {
		return err
	}

	txn.Commit()
	return nil
}
// TargetDelete removes the target identified by ref from the database and
// from the in-memory index. The index changes are committed only after the
// database update succeeds.
func (s *State) TargetDelete(ref *vagrant_plugin_sdk.Ref_Target) error {
	txn := s.inmem.Txn(true)
	defer txn.Abort()

	if err := s.db.Update(func(tx *bolt.Tx) error {
		return s.targetDelete(tx, txn, ref)
	}); err != nil {
		return err
	}

	txn.Commit()
	return nil
}
// TargetGet loads the target identified by ref's resource id from the
// database, inside a read-only transaction.
func (s *State) TargetGet(ref *vagrant_plugin_sdk.Ref_Target) (*vagrant_server.Target, error) {
	txn := s.inmem.Txn(false)
	defer txn.Abort()

	var loaded *vagrant_server.Target
	viewErr := s.db.View(func(tx *bolt.Tx) (err error) {
		loaded, err = s.targetGet(tx, txn, ref)
		return err
	})

	return loaded, viewErr
}
// TargetList returns a reference (resource id and name) for each target
// found in the in-memory index.
func (s *State) TargetList() ([]*vagrant_plugin_sdk.Ref_Target, error) {
	txn := s.inmem.Txn(false)
	defer txn.Abort()

	refs, err := s.targetList(txn)
	return refs, err
}
// targetFind resolves m to a stored target. The in-memory index is consulted
// in a fixed order, stopping at the first hit:
//
//  1. resource id
//  2. name, narrowed to the matching project when m carries a project
//  3. uuid
//
// The matching index entry's resource id is then used to load the full
// record from the database.
//
// A codes.NotFound status error is returned when no index entry matches.
// Errors reported by the index lookups themselves are returned as-is rather
// than being treated as "no match".
func (s *State) targetFind(
	dbTxn *bolt.Tx,
	memTxn *memdb.Txn,
	m *vagrant_server.Target,
) (*vagrant_server.Target, error) {
	var match *targetIndexRecord
	req := s.newTargetIndexRecord(m)

	// Start with the resource id first
	if req.Id != "" {
		raw, err := memTxn.First(
			targetIndexTableName,
			targetIndexIdIndexName,
			req.Id,
		)
		if err != nil {
			return nil, err
		}
		if raw != nil {
			match = raw.(*targetIndexRecord)
		}
	}

	// Try the name + project next
	if match == nil && req.Name != "" {
		// Match the name first
		iter, err := memTxn.Get(
			targetIndexTableName,
			targetIndexNameIndexName,
			req.Name,
		)
		if err != nil {
			return nil, err
		}

		// Without a project the first entry with this name wins; with a
		// project, only an entry belonging to that project is accepted.
		for e := iter.Next(); e != nil; e = iter.Next() {
			entry := e.(*targetIndexRecord)
			if req.ProjectId == "" || entry.ProjectId == req.ProjectId {
				match = entry
				break
			}
		}
	}

	// Finally try the uuid
	if match == nil && req.Uuid != "" {
		raw, err := memTxn.First(
			targetIndexTableName,
			targetIndexUuidName,
			req.Uuid,
		)
		if err != nil {
			return nil, err
		}
		if raw != nil {
			match = raw.(*targetIndexRecord)
		}
	}

	if match == nil {
		return nil, status.Errorf(codes.NotFound, "record not found for Target")
	}

	return s.targetGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Target{
		ResourceId: match.Id,
	})
}
// targetList builds a reference (resource id and name) for each record
// yielded by an empty-string query against the "_prefix" variant of the id
// index. The result is nil when the query yields no records.
//
// NOTE(review): an empty prefix presumably matches every record in the
// table — confirm against the go-memdb prefix index behavior.
func (s *State) targetList(
	memTxn *memdb.Txn,
) ([]*vagrant_plugin_sdk.Ref_Target, error) {
	records, err := memTxn.Get(targetIndexTableName, targetIndexIdIndexName+"_prefix", "")
	if err != nil {
		return nil, err
	}

	var refs []*vagrant_plugin_sdk.Ref_Target
	for raw := records.Next(); raw != nil; raw = records.Next() {
		rec := raw.(*targetIndexRecord)
		refs = append(refs, &vagrant_plugin_sdk.Ref_Target{
			ResourceId: rec.Id,
			Name:       rec.Name,
		})
	}

	return refs, nil
}
// targetPut stores value in the target bucket, indexes it in memory, and
// makes sure the owning project lists it.
//
// The owning project must be resolvable via value.Project; otherwise the
// error from projectGet is returned and nothing is written.
//
// When value has no resource id, an existing target is looked up through
// targetFind:
//   - if one is found, it is merged into value and value takes over its
//     resource id and uuid;
//   - if the lookup reports codes.NotFound, value is treated as a new target
//     and receives a freshly generated resource id;
//   - any other lookup error aborts the put and is returned.
//
// In that same no-resource-id path, a uuid is generated when value still has
// none afterwards.
func (s *State) targetPut(
	dbTxn *bolt.Tx,
	memTxn *memdb.Txn,
	value *vagrant_server.Target,
) (err error) {
	s.log.Trace("storing target", "target", value, "project",
		value.GetProject(), "basis", value.GetProject().GetBasis())

	p, err := s.projectGet(dbTxn, memTxn, value.Project)
	if err != nil {
		s.log.Error("failed to locate project for target", "target", value,
			"project", p, "error", err)
		return err
	}

	if value.ResourceId == "" {
		// If no resource id is provided, try to find the target based on the name and project
		foundTarget, findErr := s.targetFind(dbTxn, memTxn, value)

		// Only "not found" means "this is a new target". Every other
		// failure, status-typed or not, is a real error and must not be
		// mistaken for a missing record.
		if findErr != nil && status.Code(findErr) != codes.NotFound {
			return findErr
		}

		// Require a successful lookup before merging: a failed lookup can
		// still hand back a non-nil placeholder target, and merging that
		// would blank out the resource id and uuid.
		if findErr == nil && foundTarget != nil {
			// Merge found target with provided target
			proto.Merge(value, foundTarget)
			value.ResourceId = foundTarget.ResourceId
			value.Uuid = foundTarget.Uuid
		} else {
			s.log.Trace("target has no resource id and could not find matching target, assuming new target",
				"target", value)

			if value.ResourceId, err = s.newResourceId(); err != nil {
				s.log.Error("failed to create resource id for target", "target", value,
					"error", err)
				return err
			}
		}

		if value.Uuid == "" {
			s.log.Trace("target has no uuid assigned, assigning...", "target", value)
			uID, uuidErr := uuid.NewUUID()
			if uuidErr != nil {
				return uuidErr
			}
			value.Uuid = uID.String()
		}
	}

	s.log.Trace("storing target to db", "target", value)

	id := s.targetId(value)
	b := dbTxn.Bucket(targetBucket)
	if err = dbPut(b, id, value); err != nil {
		s.log.Error("failed to store target in db", "target", value, "error", err)
		return err
	}

	s.log.Trace("indexing target", "target", value)
	if err = s.targetIndexSet(memTxn, id, value); err != nil {
		s.log.Error("failed to index target", "target", value, "error", err)
		return err
	}

	s.log.Trace("adding target to project", "target", value, "project", p)

	// The project is rewritten only when AddTarget reports that it changed.
	pp := serverptypes.Project{Project: p}
	if pp.AddTarget(value) {
		s.log.Trace("target added to project, updating project", "project", p)
		if err = s.projectPut(dbTxn, memTxn, p); err != nil {
			s.log.Error("failed to update project", "project", p, "error", err)
			return err
		}
	} else {
		s.log.Trace("target already exists in project", "target", value, "project", p)
	}

	return nil
}
// targetGet reads the target stored under ref's resource id from the target
// bucket. memTxn is accepted but not used here.
//
// The returned pointer is never nil, even when dbGet fails; callers must
// check the error before using the result.
func (s *State) targetGet(
	dbTxn *bolt.Tx,
	memTxn *memdb.Txn,
	ref *vagrant_plugin_sdk.Ref_Target,
) (*vagrant_server.Target, error) {
	bucket := dbTxn.Bucket(targetBucket)
	key := s.targetIdByRef(ref)

	target := new(vagrant_server.Target)
	err := dbGet(bucket, key, target)
	return target, err
}
// targetDelete removes the target identified by ref from the target bucket
// and from the in-memory index, then drops it from its owning project.
//
// The owning project is loaded first, so a ref whose project cannot be
// resolved fails before anything is deleted. The project is rewritten only
// when DeleteTargetRef reports that it changed.
func (s *State) targetDelete(
	dbTxn *bolt.Tx,
	memTxn *memdb.Txn,
	ref *vagrant_plugin_sdk.Ref_Target,
) (err error) {
	// ref.Project may be nil. The generated getters are nil-safe, so a
	// missing project becomes an empty resource id, which projectGet is left
	// to reject, instead of a nil pointer dereference.
	p, err := s.projectGet(dbTxn, memTxn, &vagrant_plugin_sdk.Ref_Project{
		ResourceId: ref.GetProject().GetResourceId(),
	})
	if err != nil {
		return err
	}

	if err = dbTxn.Bucket(targetBucket).Delete(s.targetIdByRef(ref)); err != nil {
		return err
	}

	if err = memTxn.Delete(targetIndexTableName, s.newTargetIndexRecordByRef(ref)); err != nil {
		return err
	}

	pp := serverptypes.Project{Project: p}
	if pp.DeleteTargetRef(ref) {
		if err = s.projectPut(dbTxn, memTxn, p); err != nil {
			return err
		}
	}

	return nil
}
// targetIndexSet inserts the index record derived from value into the target
// index table. The id argument is not used.
func (s *State) targetIndexSet(txn *memdb.Txn, id []byte, value *vagrant_server.Target) error {
	record := s.newTargetIndexRecord(value)
	return txn.Insert(targetIndexTableName, record)
}
// targetIndexInit populates the in-memory target index from the database:
// each value in the target bucket is unmarshalled as a target and indexed.
// The first unmarshal or indexing error stops the walk and is returned.
func (s *State) targetIndexInit(dbTxn *bolt.Tx, memTxn *memdb.Txn) error {
	indexEntry := func(key, raw []byte) error {
		target := new(vagrant_server.Target)
		if err := proto.Unmarshal(raw, target); err != nil {
			return err
		}
		return s.targetIndexSet(memTxn, key, target)
	}

	return dbTxn.Bucket(targetBucket).ForEach(indexEntry)
}
// targetIndexSchema describes the in-memory table used to index targets.
// It defines four string indexes over targetIndexRecord fields:
//
//   - id:      unique, required, case-sensitive (field Id)
//   - name:    non-unique, required, lowercased (field Name)
//   - project: non-unique, required, lowercased (field ProjectId)
//   - uuid:    unique, may be missing, lowercased (field Uuid)
func targetIndexSchema() *memdb.TableSchema {
	indexes := make(map[string]*memdb.IndexSchema, 4)

	indexes[targetIndexIdIndexName] = &memdb.IndexSchema{
		Name:         targetIndexIdIndexName,
		AllowMissing: false,
		Unique:       true,
		Indexer: &memdb.StringFieldIndex{
			Field:     "Id",
			Lowercase: false,
		},
	}

	indexes[targetIndexNameIndexName] = &memdb.IndexSchema{
		Name:         targetIndexNameIndexName,
		AllowMissing: false,
		Unique:       false,
		Indexer: &memdb.StringFieldIndex{
			Field:     "Name",
			Lowercase: true,
		},
	}

	indexes[targetIndexProjectIndexName] = &memdb.IndexSchema{
		Name:         targetIndexProjectIndexName,
		AllowMissing: false,
		Unique:       false,
		Indexer: &memdb.StringFieldIndex{
			Field:     "ProjectId",
			Lowercase: true,
		},
	}

	indexes[targetIndexUuidName] = &memdb.IndexSchema{
		Name:         targetIndexUuidName,
		AllowMissing: true,
		Unique:       true,
		Indexer: &memdb.StringFieldIndex{
			Field:     "Uuid",
			Lowercase: true,
		},
	}

	return &memdb.TableSchema{
		Name:    targetIndexTableName,
		Indexes: indexes,
	}
}
// Names of the in-memory target index table and of its indexes, as used by
// targetIndexSchema and by the lookups in this file.
const (
targetIndexIdIndexName = "id"
targetIndexNameIndexName = "name"
targetIndexProjectIndexName = "project"
targetIndexUuidName = "uuid"
targetIndexTableName = "target-index"
)
// targetIndexRecord is the row type stored in the in-memory target index
// table. Each field backs one of the indexes declared in targetIndexSchema.
type targetIndexRecord struct {
Id string // Resource ID
Name string // Target Name
ProjectId string // Project Resource ID
Uuid string // Target UUID
}
// newTargetIndexRecord builds the index record for m. ProjectId is left
// empty when m has no project set.
func (s *State) newTargetIndexRecord(m *vagrant_server.Target) *targetIndexRecord {
	rec := &targetIndexRecord{
		Id:   m.ResourceId,
		Name: m.Name,
		Uuid: m.Uuid,
	}
	if project := m.Project; project != nil {
		rec.ProjectId = project.ResourceId
	}
	return rec
}
// newTargetIndexRecordByRef builds an index record from a target reference.
// The record's Uuid is left empty, and ProjectId is left empty when ref has
// no project set.
func (s *State) newTargetIndexRecordByRef(ref *vagrant_plugin_sdk.Ref_Target) *targetIndexRecord {
	rec := &targetIndexRecord{
		Id:   ref.ResourceId,
		Name: ref.Name,
	}
	if project := ref.Project; project != nil {
		rec.ProjectId = project.ResourceId
	}
	return rec
}
// targetId returns the database key for m: the bytes of its resource id.
func (s *State) targetId(m *vagrant_server.Target) []byte {
	resourceID := m.ResourceId
	return []byte(resourceID)
}
// targetIdByRef returns the database key for the target that m refers to:
// the bytes of the reference's resource id.
func (s *State) targetIdByRef(m *vagrant_plugin_sdk.Ref_Target) []byte {
	resourceID := m.ResourceId
	return []byte(resourceID)
}