293 lines
7.4 KiB
Go
293 lines
7.4 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package serverclient
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials"
|
|
|
|
"github.com/hashicorp/vagrant/internal/clicontext"
|
|
"github.com/hashicorp/vagrant/internal/protocolversion"
|
|
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
|
|
"github.com/hashicorp/vagrant/internal/serverconfig"
|
|
)
|
|
|
|
type VagrantClient struct {
|
|
vagrant_server.VagrantClient
|
|
conn *grpc.ClientConn
|
|
}
|
|
|
|
func NewVagrantClient(ctx context.Context, log hclog.Logger, addr string) (*VagrantClient, error) {
|
|
log = log.Named("vagrant.client")
|
|
conn, err := Connect(ctx,
|
|
WithAddr(addr),
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &VagrantClient{
|
|
VagrantClient: vagrant_server.NewVagrantClient(conn),
|
|
conn: conn,
|
|
}, nil
|
|
}
|
|
|
|
func WrapVagrantClient(conn *grpc.ClientConn) *VagrantClient {
|
|
return &VagrantClient{
|
|
VagrantClient: vagrant_server.NewVagrantClient(conn),
|
|
conn: conn,
|
|
}
|
|
}
|
|
|
|
func (c *VagrantClient) ServerTarget() string {
|
|
return c.conn.Target()
|
|
}
|
|
|
|
func (c *VagrantClient) Conn() *grpc.ClientConn {
|
|
return c.conn
|
|
}
|
|
|
|
// ConnectOption is used to configure how Vagrant server connection
|
|
// configuration is sourced.
|
|
type ConnectOption func(*connectConfig) error
|
|
|
|
// Connect connects to the Vagrant server. This returns the raw gRPC connection.
|
|
// You'll have to wrap it in NewVagrantClient to get the Vagrant client.
|
|
// We return the raw connection so that you have control over how to close it,
|
|
// and to support potentially alternate services in the future.
|
|
func Connect(ctx context.Context, opts ...ConnectOption) (*grpc.ClientConn, error) {
|
|
// Defaults
|
|
var cfg connectConfig
|
|
cfg.Timeout = 5 * time.Second
|
|
|
|
// Set config
|
|
for _, opt := range opts {
|
|
if err := opt(&cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if cfg.Addr == "" {
|
|
if cfg.Optional {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("no server credentials found")
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, cfg.Timeout)
|
|
defer cancel()
|
|
|
|
// Build our options
|
|
grpcOpts := []grpc.DialOption{
|
|
grpc.WithBlock(),
|
|
grpc.WithUnaryInterceptor(protocolversion.UnaryClientInterceptor(protocolversion.Current())),
|
|
grpc.WithStreamInterceptor(protocolversion.StreamClientInterceptor(protocolversion.Current())),
|
|
}
|
|
|
|
if !cfg.Tls {
|
|
grpcOpts = append(grpcOpts, grpc.WithInsecure())
|
|
} else if cfg.TlsSkipVerify {
|
|
grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(
|
|
credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}),
|
|
))
|
|
}
|
|
|
|
if cfg.Auth {
|
|
token := cfg.Token
|
|
if v := os.Getenv(EnvServerToken); v != "" {
|
|
token = v
|
|
}
|
|
|
|
if token == "" {
|
|
return nil, fmt.Errorf("No token available at the VAGRANT_SERVER_TOKEN environment variable")
|
|
}
|
|
|
|
grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(StaticToken(token)))
|
|
}
|
|
|
|
// Connect to this server
|
|
return grpc.DialContext(ctx, cfg.Addr, grpcOpts...)
|
|
}
|
|
|
|
// ContextConfig will return the context configuration for the given connection
|
|
// options.
|
|
func ContextConfig(opts ...ConnectOption) (*clicontext.Config, error) {
|
|
// Setup config
|
|
var cfg connectConfig
|
|
for _, opt := range opts {
|
|
if err := opt(&cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Build it
|
|
return &clicontext.Config{
|
|
Server: serverconfig.Client{
|
|
Address: cfg.Addr,
|
|
Tls: cfg.Tls,
|
|
TlsSkipVerify: cfg.TlsSkipVerify,
|
|
RequireAuth: cfg.Token != "",
|
|
AuthToken: cfg.Token,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// connectConfig accumulates the settings produced by ConnectOption values.
type connectConfig struct {
	// Addr is the server address handed to the gRPC dial; empty means
	// "no server configured".
	Addr string

	// Tls enables TLS, TlsSkipVerify disables certificate verification for
	// it, and Auth requests that a token be sent with every RPC.
	Tls, TlsSkipVerify, Auth bool

	// Token is the auth token from options; in Connect, VAGRANT_SERVER_TOKEN
	// takes precedence over it.
	Token string

	// Optional makes a missing Addr a non-error (see the Optional func).
	Optional bool

	// Timeout bounds the dial; Connect starts from 5 seconds.
	Timeout time.Duration
}
|
|
|
|
func WithAddr(addr string) ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
c.Addr = addr
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// FromEnv sources the connection information from the environment
|
|
// using standard environment variables.
|
|
func FromEnv() ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
if v := os.Getenv(EnvServerAddr); v != "" {
|
|
c.Addr = v
|
|
c.Tls = os.Getenv(EnvServerTls) != ""
|
|
c.TlsSkipVerify = os.Getenv(EnvServerTlsSkipVerify) != ""
|
|
c.Auth = os.Getenv(EnvServerToken) != ""
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// FromContextConfig loads a specific context config.
|
|
func FromContextConfig(cfg *clicontext.Config) ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
if cfg != nil && cfg.Server.Address != "" {
|
|
c.Addr = cfg.Server.Address
|
|
c.Tls = cfg.Server.Tls
|
|
c.TlsSkipVerify = cfg.Server.TlsSkipVerify
|
|
if cfg.Server.RequireAuth {
|
|
c.Auth = true
|
|
c.Token = cfg.Server.AuthToken
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// FromContext loads the context. This will prefer the given name. If name
|
|
// is empty, we'll respect the VAGRANT_CONTEXT env var followed by the
|
|
// default context.
|
|
func FromContext(st *clicontext.Storage, n string) ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
// Figure out what context to load. We prefer to load a manually
|
|
// specified one. If that isn't set, we prefer the env var. If that
|
|
// isn't set, we load the default.
|
|
if n == "" {
|
|
if v := os.Getenv(EnvContext); v != "" {
|
|
n = v
|
|
} else {
|
|
def, err := st.Default()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n = def
|
|
}
|
|
}
|
|
|
|
// If we still have no name, then we do nothing. We also accept
|
|
// "-" as a valid name that means "do nothing".
|
|
if n == "" || n == "-" {
|
|
return nil
|
|
}
|
|
|
|
// Load it and set it.
|
|
cfg, err := st.Load(n)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
opt := FromContextConfig(cfg)
|
|
return opt(c)
|
|
}
|
|
}
|
|
|
|
// Auth specifies that this server should require auth and therefore
|
|
// a token should be sourced from the environment and sent.
|
|
func Auth() ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
c.Auth = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Optional specifies that getting server connection information is
|
|
// optional. If this is specified and no credentials are found, Connect
|
|
// will return (nil, nil). If this is NOT specified and no credentials are
|
|
// found, it is an error.
|
|
func Optional() ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
c.Optional = true
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Timeout specifies a connection timeout. This defaults to 5 seconds.
|
|
func Timeout(t time.Duration) ConnectOption {
|
|
return func(c *connectConfig) error {
|
|
c.Timeout = t
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Environment variables understood by this package.
const (
	// EnvServerAddr is the address of the Vagrant server. It should be in
	// the form "ip:port" for TCP.
	EnvServerAddr = "VAGRANT_SERVER_ADDR"

	// EnvServerTls enables TLS for the server connection when set; FromEnv
	// defines exactly which values count as enabled.
	EnvServerTls = "VAGRANT_SERVER_TLS"

	// EnvServerTlsSkipVerify disables verification of the server's TLS
	// certificate when set, interpreted the same way as EnvServerTls.
	EnvServerTlsSkipVerify = "VAGRANT_SERVER_TLS_SKIP_VERIFY"

	// EnvServerToken is the token used to authenticate with the server.
	EnvServerToken = "VAGRANT_SERVER_TOKEN"

	// EnvContext names the CLI context to load when none is given
	// explicitly.
	EnvContext = "VAGRANT_CONTEXT"
)
|
|
|
|
// StaticToken is a fixed authentication token sent with every RPC.
//
// It exists only to satisfy the per-RPC credentials interface required by
// grpc.WithPerRPCCredentials. That API is designed for schemes such as OAuth.
// Here the same token string is attached unchanged to every request.
type StaticToken string

// GetRequestMetadata returns the token under the "authorization" metadata key.
// The context and URIs are ignored and no error is ever returned.
func (t StaticToken) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md["authorization"] = string(t)
	return md, nil
}

// RequireTransportSecurity reports false, so gRPC will also send the token
// over connections dialed without TLS.
func (t StaticToken) RequireTransportSecurity() bool {
	const needsTLS = false
	return needsTLS
}
|