142 lines
3.8 KiB
Go
142 lines
3.8 KiB
Go
package state
|
|
|
|
import (
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
bolt "go.etcd.io/bbolt"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// sysBucket is the name of the bucket that stores system-related
// information.
var sysBucket = []byte("system")

// sysVersionKey is the key inside sysBucket under which the version of the
// stored data is kept. This is used for data migration.
var sysVersionKey = []byte("version")
|
|
|
|
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
//
// DB Version
//
// THIS SHOULD BE CHANGED WITH EXTREME CAUTION. Changing this will force users
// to perform a db upgrade when they upgrade their Vagrant version.
//
//!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// dbVersion is the data version this server reads and writes. dbInit stores
// it in a database that has no version yet and refuses to continue when the
// stored version differs from it.
const dbVersion int64 = 1
|
|
|
|
func init() {
|
|
dbBuckets = append(dbBuckets, sysBucket)
|
|
}
|
|
|
|
// dbInit sets up the database. This should be called once on all new
|
|
// DB handles before accepting API calls. It is safe to be called multiple
|
|
// times.
|
|
func dbInit(db *bolt.DB) error {
|
|
return db.Update(func(tx *bolt.Tx) error {
|
|
// Create all our buckets
|
|
for _, b := range dbBuckets {
|
|
if _, err := tx.CreateBucketIfNotExists(b); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Check our data version
|
|
sys := tx.Bucket(sysBucket)
|
|
vsnRaw := sys.Get(sysVersionKey)
|
|
|
|
// Initialize the version with our current version if it isn't set.
|
|
if len(vsnRaw) == 0 {
|
|
if err := sys.Put(sysVersionKey, []byte(strconv.FormatInt(dbVersion, 10))); err != nil {
|
|
return status.Errorf(codes.Internal,
|
|
"failed to write initial database version: %s", err)
|
|
}
|
|
} else {
|
|
vsn, err := strconv.ParseInt(string(vsnRaw), 10, 64)
|
|
if err != nil {
|
|
return status.Errorf(codes.Internal,
|
|
"failed to read database version: %s", err)
|
|
}
|
|
|
|
if vsn != dbVersion {
|
|
return status.Errorf(codes.FailedPrecondition, strings.TrimSpace(`
|
|
The database version on disk does not match the server version.
|
|
|
|
The server cannot safely read this data. Please upgrade or downgrade your server
|
|
to a version that is capable of reading this data version. You can find this
|
|
information on the Vagrant website.
|
|
|
|
On-disk data version: %d
|
|
Server data version: %d
|
|
|
|
`), vsn, dbVersion)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// dbPut is a helper to insert a proto.Message into a bucket for the given id.
|
|
// Any errors are automatically wrapped into a gRPC status error so they can
|
|
// be sent directly back.
|
|
func dbPut(b *bolt.Bucket, id []byte, msg proto.Message) error {
|
|
enc, err := proto.Marshal(msg)
|
|
if err != nil {
|
|
return status.Errorf(codes.Internal, "failed to encode data: %s", err)
|
|
}
|
|
|
|
if err := b.Put(id, enc); err != nil {
|
|
return status.Errorf(codes.Aborted, "failed to write data: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// dbGet is a helper to get a single proto.Message from a bucket. Errors
|
|
// are guaranteed to be in gRPC status format.
|
|
func dbGet(b *bolt.Bucket, id []byte, msg proto.Message) error {
|
|
raw := b.Get(id)
|
|
if raw == nil {
|
|
return status.Errorf(codes.NotFound, "record not found for ID: %s", id)
|
|
}
|
|
|
|
if err := proto.Unmarshal(raw, msg); err != nil {
|
|
return status.Errorf(codes.Internal, "failed to decode data: %s", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// dbUpsert is a helper to upsert a message. The update boolean will cause
|
|
// this to error if the ID is not found. This reflects our API behavior for
|
|
// upserts so that we don't let the end user pick any ID.
|
|
func dbUpsert(b *bolt.Bucket, update bool, id []byte, msg proto.Message) error {
|
|
// If we're updating, the ID must exist
|
|
if update && b.Get([]byte(id)) == nil {
|
|
return status.Errorf(codes.NotFound, "record not found for ID: %s", id)
|
|
}
|
|
|
|
// Insert
|
|
return dbPut(b, id, msg)
|
|
}
|
|
|
|
func dbCount(b *bolt.DB, tableName string) int {
|
|
count := 0
|
|
b.View(func(tx *bolt.Tx) error {
|
|
bucket := tx.Bucket([]byte(tableName))
|
|
bucket.ForEach(func(k, v []byte) error {
|
|
count++
|
|
return nil
|
|
})
|
|
return nil
|
|
})
|
|
return count
|
|
}
|