vaguerent/internal/server/grpc_version_test.go

197 lines
4.6 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package server
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/hashicorp/vagrant/internal/protocolversion"
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
)
// TestVersionUnaryInterceptor exercises the unary interceptor built by
// versionUnaryInterceptor. The server in this test advertises API protocol
// versions 2..10 and entrypoint protocol versions 5..10.
//
// Each case sends a call for a given FullMethod with a given set of incoming
// metadata headers to a stub handler. The case then checks three things:
//   - whether the handler was invoked;
//   - whether the interceptor returned an error, and with what text;
//   - for accepted calls, the protocol version that
//     protocolversion.FromContext reports on the handler's context, where
//     that version is specified.
func TestVersionUnaryInterceptor(t *testing.T) {
	f := versionUnaryInterceptor(&vagrant_server.VersionInfo{
		Api: &vagrant_server.VersionInfo_ProtocolVersion{
			Current: 10,
			Minimum: 2,
		},
		Entrypoint: &vagrant_server.VersionInfo_ProtocolVersion{
			Current: 10,
			Minimum: 5,
		},
		Version: "1.2.3",
	})

	cases := []struct {
		name   string
		method string   // FullMethod handed to the interceptor via UnaryServerInfo
		pairs  []string // alternating key/value pairs for the incoming metadata

		// errSubstr, when non-empty, means the call must be rejected: the
		// handler must not run and the returned error must contain this text.
		// When empty, the handler must run and no error may be returned.
		errSubstr string

		// version, when non-zero, is the value protocolversion.FromContext
		// must report for the context the handler received.
		version uint32
	}{
		{
			name:      "no headers",
			method:    "/hashicorp.vagrant.Vagrant/Foo",
			errSubstr: "is not set",
		},
		{
			name:   "no headers on GetVersionInfo",
			method: "/hashicorp.vagrant.Vagrant/GetVersionInfo",
		},
		{
			name:   "no headers on a different service",
			method: "/hashicorp.notvagrant.Vagrant/Foo",
		},
		{
			name:   "valid API",
			method: "/hashicorp.vagrant.Vagrant/Foo",
			pairs: []string{
				protocolversion.HeaderClientApiProtocol, "4,7",
			},
			version: 7,
		},
		{
			name:   "invalid API",
			method: "/hashicorp.vagrant.Vagrant/Foo",
			pairs: []string{
				protocolversion.HeaderClientApiProtocol, "11,14",
				protocolversion.HeaderClientEntrypointProtocol, "4,7",
			},
			errSubstr: "outdated",
		},
		{
			name:   "valid Entrypoint",
			method: "/hashicorp.vagrant.Vagrant/EntrypointFoo",
			pairs: []string{
				protocolversion.HeaderClientApiProtocol, "4,7",
				protocolversion.HeaderClientEntrypointProtocol, "4,6",
			},
			version: 6,
		},
		{
			name:   "invalid Entrypoint",
			method: "/hashicorp.vagrant.Vagrant/EntrypointFoo",
			pairs: []string{
				protocolversion.HeaderClientApiProtocol, "4,7",
				protocolversion.HeaderClientEntrypointProtocol, "2,3",
			},
			errSubstr: "outdated",
		},
	}

	for _, tc := range cases {
		tc := tc // keep each subtest bound to its own case on older Go toolchains
		t.Run(tc.name, func(t *testing.T) {
			require := require.New(t)
			ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(tc.pairs...))

			var (
				handled bool
				seen    context.Context
			)
			stub := func(hctx context.Context, req interface{}) (interface{}, error) {
				handled = true
				seen = hctx
				return nil, nil
			}

			_, err := f(ctx, nil, &grpc.UnaryServerInfo{FullMethod: tc.method}, stub)

			if tc.errSubstr != "" {
				// Rejected calls must short-circuit before reaching the handler.
				require.False(handled)
				require.Error(err)
				require.Contains(err.Error(), tc.errSubstr)
				return
			}

			require.True(handled)
			require.NoError(err)
			if tc.version != 0 {
				require.Equal(tc.version, protocolversion.FromContext(seen))
			}
		})
	}
}