// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package state
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
|
|
"github.com/hashicorp/vagrant-plugin-sdk/internal-shared/dynamic"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"gorm.io/datatypes"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/migrator"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
// ProtoValue stores a protobuf message in the database
|
|
// as a JSON field so it can be queried. Note that using
|
|
// this type can result in lossy storage depending on
|
|
// types in the message
|
|
type ProtoValue struct {
|
|
Message proto.Message
|
|
}
|
|
|
|
// User consumable data type name
|
|
func (p *ProtoValue) GormDataType() string {
|
|
return datatypes.JSON{}.GormDataType()
|
|
}
|
|
|
|
// Driver consumable data type name
|
|
func (p *ProtoValue) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
return datatypes.JSON{}.GormDBDataType(db, field)
|
|
}
|
|
|
|
// Unmarshals the store value back to original type
|
|
func (p *ProtoValue) Scan(value interface{}) error {
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
|
|
s, ok := value.(string)
|
|
if !ok {
|
|
return fmt.Errorf("failed to unmarshal protobuf value, invalid type (%T)", value)
|
|
}
|
|
|
|
if s == "" {
|
|
return nil
|
|
}
|
|
|
|
v := []byte(s)
|
|
var m anypb.Any
|
|
err := protojson.Unmarshal(v, &m)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, i, err := dynamic.DecodeAny(&m)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pm, ok := i.(proto.Message)
|
|
if !ok {
|
|
return fmt.Errorf("failed to set unmarshaled proto value, invalid type (%T)", i)
|
|
}
|
|
|
|
p.Message = pm
|
|
|
|
return nil
|
|
}
|
|
|
|
// Marshal the value for storage in the database
|
|
func (p *ProtoValue) Value() (driver.Value, error) {
|
|
if p == nil || p.Message == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
a, err := dynamic.EncodeAny(p.Message)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
j, err := protojson.Marshal(a)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return string(j), nil
|
|
}
|
|
|
|
// ProtoRaw stores a protobuf message in the database
|
|
// as raw bytes. Note that when using this type the
|
|
// contents of the protobuf message cannot be queried
|
|
type ProtoRaw struct {
|
|
Message proto.Message
|
|
}
|
|
|
|
// User consumable data type name
|
|
func (p *ProtoRaw) GormDataType() string {
|
|
return "bytes"
|
|
}
|
|
|
|
// Driver consumable data type name
|
|
func (p *ProtoRaw) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
return "BLOB"
|
|
}
|
|
|
|
// Unmarshals the store value back to original type
|
|
func (p *ProtoRaw) Scan(value interface{}) error {
|
|
if p == nil || value == nil {
|
|
return nil
|
|
}
|
|
|
|
s, ok := value.(string)
|
|
if !ok {
|
|
return fmt.Errorf("failed to unmarshal protobuf raw, invalid type (%T)", value)
|
|
}
|
|
|
|
if s == "" {
|
|
return nil
|
|
}
|
|
|
|
v := []byte(s)
|
|
var a anypb.Any
|
|
err := proto.Unmarshal(v, &a)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, m, err := dynamic.DecodeAny(&a)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pm, ok := m.(proto.Message)
|
|
if !ok {
|
|
return fmt.Errorf("failed to set unmarshaled proto raw, invalid type (%T)", m)
|
|
}
|
|
|
|
p.Message = pm
|
|
|
|
return nil
|
|
}
|
|
|
|
// Marshal the value for storage in the database
|
|
func (p *ProtoRaw) Value() (driver.Value, error) {
|
|
if p == nil || p.Message == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
m, err := dynamic.EncodeAny(p.Message)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r, err := proto.Marshal(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return string(r), nil
|
|
}
|
|
|
|
var (
|
|
_ sql.Scanner = (*ProtoValue)(nil)
|
|
_ driver.Valuer = (*ProtoValue)(nil)
|
|
_ schema.GormDataTypeInterface = (*ProtoValue)(nil)
|
|
_ migrator.GormDataTypeInterface = (*ProtoValue)(nil)
|
|
_ sql.Scanner = (*ProtoRaw)(nil)
|
|
_ driver.Valuer = (*ProtoRaw)(nil)
|
|
_ schema.GormDataTypeInterface = (*ProtoRaw)(nil)
|
|
_ migrator.GormDataTypeInterface = (*ProtoRaw)(nil)
|
|
)
|