159 lines
3.9 KiB
Go
159 lines
3.9 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"github.com/hashicorp/vagrant/internal/protocolversion"
|
|
"github.com/hashicorp/vagrant/internal/server/proto/vagrant_server"
|
|
)
|
|
|
|
// versionUnaryInterceptor returns a gRPC unary interceptor that negotiates
|
|
// the protocol version to use and sets it in the context using
|
|
// protocolversion.WithContext.
|
|
func versionUnaryInterceptor(serverInfo *vagrant_server.VersionInfo) grpc.UnaryServerInterceptor {
|
|
return func(
|
|
ctx context.Context,
|
|
req interface{},
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler) (interface{}, error) {
|
|
typ, ok := versionType(info.FullMethod)
|
|
if !ok {
|
|
return handler(ctx, req)
|
|
}
|
|
|
|
ctx, err := versionContext(ctx, typ, serverInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return handler(ctx, req)
|
|
}
|
|
}
|
|
|
|
// versionStreamInterceptor returns a gRPC unary interceptor that negotiates
|
|
// the protocol version to use and sets it in the context using
|
|
// protocolversion.WithContext.
|
|
func versionStreamInterceptor(serverInfo *vagrant_server.VersionInfo) grpc.StreamServerInterceptor {
|
|
return func(
|
|
srv interface{},
|
|
ss grpc.ServerStream,
|
|
info *grpc.StreamServerInfo,
|
|
handler grpc.StreamHandler) error {
|
|
typ, ok := versionType(info.FullMethod)
|
|
if !ok {
|
|
return handler(srv, ss)
|
|
}
|
|
|
|
ctx := ss.Context()
|
|
ctx, err := versionContext(ctx, typ, serverInfo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Invoke the handler.
|
|
return handler(srv, &versionStream{
|
|
ServerStream: ss,
|
|
context: ctx,
|
|
})
|
|
}
|
|
}
|
|
|
|
// versionType returns the type of protocol version we should negotiate.
|
|
func versionType(fullMethod string) (protocolversion.Type, bool) {
|
|
// Only care about vagrant APIs and ignore the version info call.
|
|
if !strings.HasPrefix(fullMethod, "/hashicorp.vagrant.Vagrant/") {
|
|
return protocolversion.Invalid, false
|
|
}
|
|
|
|
// Get the method
|
|
idx := strings.LastIndex(fullMethod, "/")
|
|
if idx == -1 {
|
|
return protocolversion.Invalid, false
|
|
}
|
|
method := fullMethod[idx+1:]
|
|
|
|
// If it is a version method we don't negotiate versions at all.
|
|
if method == "GetVersionInfo" {
|
|
return protocolversion.Invalid, false
|
|
}
|
|
|
|
// Determine what API is being called
|
|
typ := protocolversion.Api
|
|
if strings.HasPrefix(method, "Entrypoint") {
|
|
typ = protocolversion.Entrypoint
|
|
}
|
|
|
|
return typ, true
|
|
}
|
|
|
|
// versionContext
|
|
func versionContext(
|
|
ctx context.Context,
|
|
typ protocolversion.Type,
|
|
info *vagrant_server.VersionInfo,
|
|
) (context.Context, error) {
|
|
var header string
|
|
var server *vagrant_server.VersionInfo_ProtocolVersion
|
|
switch typ {
|
|
case protocolversion.Api:
|
|
header = protocolversion.HeaderClientApiProtocol
|
|
server = info.Api
|
|
|
|
case protocolversion.Entrypoint:
|
|
header = protocolversion.HeaderClientEntrypointProtocol
|
|
server = info.Entrypoint
|
|
|
|
default:
|
|
return nil, status.Errorf(codes.Internal, "invalid protocol type")
|
|
}
|
|
|
|
// Get our metadata
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return nil, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
|
|
}
|
|
|
|
// Get the client version information
|
|
vs := md[header]
|
|
if len(vs) != 1 {
|
|
return nil, status.Errorf(codes.InvalidArgument,
|
|
"required header %s is not set", header)
|
|
}
|
|
min, current, err := protocolversion.ParseHeader(vs[0])
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument,
|
|
"header %q: %s", header, err)
|
|
}
|
|
|
|
// Negotiate the version to use
|
|
version, err := protocolversion.Negotiate(&vagrant_server.VersionInfo_ProtocolVersion{
|
|
Current: current,
|
|
Minimum: min,
|
|
}, server)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument,
|
|
"header %q: %s", header, err)
|
|
}
|
|
|
|
// Invoke the handler.
|
|
return protocolversion.WithContext(ctx, version), nil
|
|
}
|
|
|
|
type versionStream struct {
|
|
grpc.ServerStream
|
|
context context.Context
|
|
}
|
|
|
|
func (s *versionStream) Context() context.Context {
|
|
return s.context
|
|
}
|