From 84201981ea2688e406a3fd119764ac1f8586863e Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Thu, 28 Mar 2024 19:35:45 +0300 Subject: [PATCH] chore: support WG over GRPC in Omni This PR adds the support for WG over GRPC. New field `VirtualAddrport` in `SiderolinkSpec` should allow for both setting the virtual addr and loading it after the Omni restart. Signed-off-by: Dmitriy Matrenichev --- client/api/omni/specs/siderolink.pb.go | 71 +++++----- client/api/omni/specs/siderolink.proto | 1 + .../api/omni/specs/siderolink_vtproto.pb.go | 47 +++++++ go.mod | 4 +- go.sum | 8 +- hack/compose/docker-compose.yml | 2 + internal/pkg/siderolink/manager.go | 122 ++++++++++++++--- internal/pkg/siderolink/siderolink_test.go | 125 +++++++++++++++++- internal/pkg/siderolink/wireguard.go | 31 +++-- 9 files changed, 338 insertions(+), 73 deletions(-) diff --git a/client/api/omni/specs/siderolink.pb.go b/client/api/omni/specs/siderolink.pb.go index 8d4d90d0..c14757dc 100644 --- a/client/api/omni/specs/siderolink.pb.go +++ b/client/api/omni/specs/siderolink.pb.go @@ -124,10 +124,11 @@ type SiderolinkSpec struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - NodeSubnet string `protobuf:"bytes,1,opt,name=node_subnet,json=nodeSubnet,proto3" json:"node_subnet,omitempty"` - NodePublicKey string `protobuf:"bytes,2,opt,name=node_public_key,json=nodePublicKey,proto3" json:"node_public_key,omitempty"` - LastEndpoint string `protobuf:"bytes,3,opt,name=last_endpoint,json=lastEndpoint,proto3" json:"last_endpoint,omitempty"` - Connected bool `protobuf:"varint,4,opt,name=connected,proto3" json:"connected,omitempty"` + NodeSubnet string `protobuf:"bytes,1,opt,name=node_subnet,json=nodeSubnet,proto3" json:"node_subnet,omitempty"` + NodePublicKey string `protobuf:"bytes,2,opt,name=node_public_key,json=nodePublicKey,proto3" json:"node_public_key,omitempty"` + LastEndpoint string `protobuf:"bytes,3,opt,name=last_endpoint,json=lastEndpoint,proto3" json:"last_endpoint,omitempty"` + Connected bool `protobuf:"varint,4,opt,name=connected,proto3" json:"connected,omitempty"` + VirtualAddrport string `protobuf:"bytes,7,opt,name=virtual_addrport,json=virtualAddrport,proto3" json:"virtual_addrport,omitempty"` } func (x *SiderolinkSpec) Reset() { @@ -190,6 +191,13 @@ func (x *SiderolinkSpec) GetConnected() bool { return false } +func (x *SiderolinkSpec) GetVirtualAddrport() string { + if x != nil { + return x.VirtualAddrport + } + return "" +} + // SiderolinkConnectionSpec describes each node connection information. type SiderolinkCounterSpec struct { state protoimpl.MessageState @@ -355,7 +363,7 @@ var file_omni_specs_siderolink_proto_rawDesc = []byte{ 0x52, 0x09, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2f, 0x0a, 0x13, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, - 0x69, 0x73, 0x65, 0x64, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0xa8, 0x01, 0x0a, + 0x69, 0x73, 0x65, 0x64, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x22, 0xd3, 0x01, 0x0a, 0x0e, 0x53, 0x69, 0x64, 0x65, 0x72, 0x6f, 0x6c, 0x69, 0x6e, 0x6b, 0x53, 0x70, 0x65, 0x63, 0x12, 0x1f, 0x0a, 0x0b, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x73, 0x75, 0x62, 0x6e, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6e, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x62, 0x6e, 0x65, 0x74, @@ -365,31 +373,34 @@ var file_omni_specs_siderolink_proto_rawDesc = []byte{ 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6c, 0x61, 0x73, 0x74, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x4a, 0x04, 0x08, 0x05, 0x10, - 0x06, 0x4a, 0x04, 0x08, 0x06, 0x10, 0x07, 0x22, 0x98, 0x01, 0x0a, 0x15, 0x53, 0x69, 0x64, 0x65, - 0x72, 0x6f, 0x6c, 0x69, 0x6e, 0x6b, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x53, 0x70, 0x65, - 0x63, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x72, 0x65, 0x63, 0x65, 0x69, - 0x76, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x62, 0x79, 0x74, 0x65, - 0x73, 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x62, 0x79, - 0x74, 0x65, 0x73, 0x53, 0x65, 0x6e, 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x6c, 0x61, 0x73, 0x74, 0x5f, - 0x61, 0x6c, 0x69, 0x76, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x41, 0x6c, 0x69, - 0x76, 0x65, 0x22, 0x9b, 0x01, 0x0a, 0x14, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x53, 0x70, 0x65, 0x63, 0x12, 0x12, 0x0a, 0x04, 0x61, - 0x72, 0x67, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x72, 0x67, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x61, 0x70, 0x69, 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x70, 0x69, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x12, 0x2d, 0x0a, 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x5f, - 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, - 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x6a, 0x6f, 0x69, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, - 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, - 0x69, 0x64, 0x65, 0x72, 0x6f, 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x6f, 0x6d, 0x6e, 0x69, 0x2f, 0x63, - 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6f, 0x6d, 0x6e, 0x69, 0x2f, 0x73, - 0x70, 0x65, 0x63, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x29, 0x0a, 0x10, 0x76, + 0x69, 0x72, 0x74, 0x75, 0x61, 0x6c, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x70, 0x6f, 0x72, 0x74, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x69, 0x72, 0x74, 0x75, 0x61, 0x6c, 0x41, 0x64, + 0x64, 0x72, 0x70, 0x6f, 0x72, 0x74, 0x4a, 0x04, 0x08, 0x05, 0x10, 0x06, 0x4a, 0x04, 0x08, 0x06, + 0x10, 0x07, 0x22, 0x98, 0x01, 0x0a, 0x15, 0x53, 0x69, 0x64, 0x65, 0x72, 0x6f, 0x6c, 0x69, 0x6e, + 0x6b, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x12, 0x25, 0x0a, 0x0e, + 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x72, 0x65, 0x63, 0x65, 0x69, 0x76, 0x65, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x65, 0x63, 0x65, 0x69, + 0x76, 0x65, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x62, 0x79, 0x74, 0x65, 0x73, 0x5f, 0x73, 0x65, 0x6e, + 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x62, 0x79, 0x74, 0x65, 0x73, 0x53, 0x65, + 0x6e, 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x61, 0x6c, 0x69, 0x76, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x52, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x41, 0x6c, 0x69, 0x76, 0x65, 0x22, 0x9b, 0x01, + 0x0a, 0x14, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x61, 0x72, 0x61, + 0x6d, 0x73, 0x53, 0x70, 0x65, 0x63, 0x12, 0x12, 0x0a, 0x04, 0x61, 0x72, 0x67, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x61, 0x72, 0x67, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x70, + 0x69, 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x61, 0x70, 0x69, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x2d, 0x0a, + 0x12, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x5f, 0x65, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x77, 0x69, 0x72, 0x65, 0x67, + 0x75, 0x61, 0x72, 0x64, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, + 0x6a, 0x6f, 0x69, 0x6e, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x6a, 0x6f, 0x69, 0x6e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x42, 0x32, 0x5a, 0x30, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x69, 0x64, 0x65, 0x72, 0x6f, + 0x6c, 0x61, 0x62, 0x73, 0x2f, 0x6f, 0x6d, 0x6e, 0x69, 0x2f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6f, 0x6d, 0x6e, 0x69, 0x2f, 0x73, 0x70, 0x65, 0x63, 0x73, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/api/omni/specs/siderolink.proto b/client/api/omni/specs/siderolink.proto index b88bdaeb..69316adb 100644 --- a/client/api/omni/specs/siderolink.proto +++ b/client/api/omni/specs/siderolink.proto @@ -24,6 +24,7 @@ message SiderolinkSpec { bool connected = 4; reserved 5; reserved 6; + string virtual_addrport = 7; } // SiderolinkConnectionSpec describes each node connection information. diff --git a/client/api/omni/specs/siderolink_vtproto.pb.go b/client/api/omni/specs/siderolink_vtproto.pb.go index da8880a9..54f262d6 100644 --- a/client/api/omni/specs/siderolink_vtproto.pb.go +++ b/client/api/omni/specs/siderolink_vtproto.pb.go @@ -54,6 +54,7 @@ func (m *SiderolinkSpec) CloneVT() *SiderolinkSpec { r.NodePublicKey = m.NodePublicKey r.LastEndpoint = m.LastEndpoint r.Connected = m.Connected + r.VirtualAddrport = m.VirtualAddrport if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -159,6 +160,9 @@ func (this *SiderolinkSpec) EqualVT(that *SiderolinkSpec) bool { if this.Connected != that.Connected { return false } + if this.VirtualAddrport != that.VirtualAddrport { + return false + } return string(this.unknownFields) == string(that.unknownFields) } @@ -334,6 +338,13 @@ func (m *SiderolinkSpec) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if len(m.VirtualAddrport) > 0 { + i -= len(m.VirtualAddrport) + copy(dAtA[i:], m.VirtualAddrport) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.VirtualAddrport))) + i-- + dAtA[i] = 0x3a + } if m.Connected { i-- if m.Connected { @@ -541,6 +552,10 @@ func (m *SiderolinkSpec) SizeVT() (n int) { if m.Connected { n += 2 } + l = len(m.VirtualAddrport) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -1011,6 +1026,38 @@ func (m *SiderolinkSpec) UnmarshalVT(dAtA []byte) error { } } m.Connected = bool(v != 0) + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VirtualAddrport", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.VirtualAddrport = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go.mod b/go.mod index c46b7f8b..8c74c2ac 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,7 @@ require ( github.com/siderolabs/image-factory v0.2.2 github.com/siderolabs/kms-client v0.1.0 github.com/siderolabs/omni/client v0.0.0-00010101000000-000000000000 - github.com/siderolabs/siderolink v0.3.4 + github.com/siderolabs/siderolink v0.3.5 github.com/siderolabs/talos/pkg/machinery v1.7.0-beta.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 @@ -87,6 +87,7 @@ require ( golang.org/x/sync v0.6.0 golang.org/x/text v0.14.0 golang.org/x/tools v0.19.0 + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 @@ -237,7 +238,6 @@ require ( golang.org/x/term v0.18.0 // indirect golang.org/x/time v0.5.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect diff --git a/go.sum b/go.sum index dbbea9e3..0ad3d0bd 100644 --- a/go.sum +++ b/go.sum @@ -491,8 +491,8 @@ github.com/siderolabs/net v0.4.0 h1:1bOgVay/ijPkJz4qct98nHsiB/ysLQU0KLoBC4qLm7I= github.com/siderolabs/net v0.4.0/go.mod h1:/ibG+Hm9HU27agp5r9Q3eZicEfjquzNzQNux5uEk0kM= github.com/siderolabs/protoenc v0.2.1 h1:BqxEmeWQeMpNP3R6WrPqDatX8sM/r4t97OP8mFmg6GA= github.com/siderolabs/protoenc v0.2.1/go.mod h1:StTHxjet1g11GpNAWiATgc8K0HMKiFSEVVFOa/H0otc= -github.com/siderolabs/siderolink v0.3.4 h1:850JRSSrD3EEDh35h6wiSTtRiGuclEc/6k4wx/It4nU= -github.com/siderolabs/siderolink v0.3.4/go.mod h1:juxlSF9cBzeBHsOjS7hVS3s0NDpC034i/OZunVReqmo= +github.com/siderolabs/siderolink v0.3.5 h1:sU4WNGCRGQYZ/sQZaVQbGfUNOqS561oL4kafKlo4FDY= +github.com/siderolabs/siderolink v0.3.5/go.mod h1:/7Dg0Nkh4q/8yqsY/VirDOTOFOqRvPikagCoyf3+Mf4= github.com/siderolabs/talos/pkg/machinery v1.7.0-beta.0 h1:fOn3uKNKA1bzHGCeOoaE8Dy40UH9Z6PHaf/XYdFwVy8= github.com/siderolabs/talos/pkg/machinery v1.7.0-beta.0/go.mod h1:YBl9KDCD45Uc7N0rXBY1JqovUn1n46ekUPSNbEVZzQU= github.com/siderolabs/tcpproxy v0.1.0 h1:IbkS9vRhjMOscc1US3M5P1RnsGKFgB6U5IzUk+4WkKA= @@ -738,8 +738,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb h1:c5tyN8sSp8jSDxdCCDXVOpJwYXXhmTkNMt+g0zTSOic= -golang.zx2c4.com/wireguard v0.0.0-20231022001213-2e0774f246fb/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/hack/compose/docker-compose.yml b/hack/compose/docker-compose.yml index 3bc585db..31624b53 100644 --- a/hack/compose/docker-compose.yml +++ b/hack/compose/docker-compose.yml @@ -7,6 +7,8 @@ version: '3.8' services: omni: network_mode: host + devices: + - /dev/net/tun depends_on: - vault-dev - node-dev diff --git a/internal/pkg/siderolink/manager.go b/internal/pkg/siderolink/manager.go index 39d5aa9b..495ffc3a 100644 --- a/internal/pkg/siderolink/manager.go +++ b/internal/pkg/siderolink/manager.go @@ -11,7 +11,6 @@ import ( "crypto/tls" "errors" "fmt" - "io" "net" "net/netip" "os" @@ -25,15 +24,19 @@ import ( "github.com/jxskiss/base62" "github.com/prometheus/client_golang/prometheus" "github.com/siderolabs/gen/channel" + "github.com/siderolabs/go-pointer" "github.com/siderolabs/go-retry/retry" eventsapi "github.com/siderolabs/siderolink/api/events" pb "github.com/siderolabs/siderolink/api/siderolink" "github.com/siderolabs/siderolink/pkg/events" + "github.com/siderolabs/siderolink/pkg/wgtunnel/wgbind" + "github.com/siderolabs/siderolink/pkg/wgtunnel/wggrpc" "github.com/siderolabs/siderolink/pkg/wireguard" machineapi "github.com/siderolabs/talos/pkg/machinery/api/machine" talosconstants "github.com/siderolabs/talos/pkg/machinery/constants" "github.com/siderolabs/talos/pkg/machinery/proto" "go.uber.org/zap" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -61,6 +64,9 @@ type LinkCounterDelta struct { BytesReceived int64 } +// maxPendingClientMessages sets the maximum number of messages for queue "from peers" after which it will block. +const maxPendingClientMessages = 100 + // NewManager creates new Manager. func NewManager( ctx context.Context, @@ -72,12 +78,10 @@ func NewManager( deltaCh chan<- LinkCounterDeltas, ) (*Manager, error) { manager := &Manager{ + logger: logger, state: state, wgHandler: wgHandler, - logger: logger, logHandler: handler, - deltaCh: deltaCh, - metricBytesReceived: prometheus.NewCounter(prometheus.CounterOpts{ Name: "omni_siderolink_received_bytes_total", Help: "Number of bytes received from the SideroLink interface.", @@ -98,6 +102,10 @@ func NewManager( 64 * 60, // more than hour... wth? }, }), + deltaCh: deltaCh, + allowedPeers: wggrpc.NewAllowedPeers(), + peerTraffic: wgbind.NewPeerTraffic(maxPendingClientMessages), + virtualPrefix: wireguard.VirtualNetworkPrefix(), } nodePrefix := wireguard.NetworkPrefix("") @@ -131,6 +139,9 @@ type Manager struct { metricLastHandshake prometheus.Histogram deltaCh chan<- LinkCounterDeltas serverAddr netip.Prefix + allowedPeers *wggrpc.AllowedPeers + peerTraffic *wgbind.PeerTraffic + virtualPrefix netip.Prefix } // JoinTokenLen number of random bytes to be encoded in the join token. @@ -230,10 +241,9 @@ func createListener(ctx context.Context, host, port string) (net.Listener, error } // Register implements controller.Manager interface. -func (manager *Manager) Register( - server *grpc.Server, -) { +func (manager *Manager) Register(server *grpc.Server) { pb.RegisterProvisionServiceServer(server, manager) + pb.RegisterWireGuardOverGRPCServiceServer(server, wggrpc.NewService(manager.peerTraffic, manager.allowedPeers, manager.logger)) } // Run implements controller.Manager interface. @@ -347,7 +357,28 @@ func (manager *Manager) startWireguard(ctx context.Context, eg *errgroup.Group, return fmt.Errorf("invalid private key: %w", err) } - if err = manager.wgHandler.SetupDevice(serverAddr, key, manager.wgConfig().WireguardEndpoint, manager.logger); err != nil { + _, strPort, err := net.SplitHostPort(manager.wgConfig().WireguardEndpoint) + if err != nil { + return fmt.Errorf("invalid Wireguard endpoint: %w", err) + } + + port, err := strconv.Atoi(strPort) + if err != nil { + return fmt.Errorf("invalid Wireguard endpoint port: %w", err) + } + + peerHandler := &peerHandler{ + allowedPeers: manager.allowedPeers, + } + + if err = manager.wgHandler.SetupDevice(wireguard.DeviceConfig{ + Bind: wgbind.NewServerBind(conn.NewDefaultBind(), manager.virtualPrefix, manager.peerTraffic, manager.logger), + PeerHandler: peerHandler, + Logger: manager.logger, + ServerPrefix: serverAddr, + PrivateKey: key, + ListenPort: uint16(port), + }); err != nil { return err } @@ -637,7 +668,7 @@ func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) } spec := link.TypedSpec().Value - if spec.NodePublicKey != req.NodePublicKey { + if spec.NodePublicKey != req.NodePublicKey || tunnelStatusChanged(req, link) { if _, err = safe.StateUpdateWithConflicts(ctx, manager.state, link.Metadata(), func(r *siderolink.Link) error { s := r.TypedSpec().Value @@ -646,6 +677,10 @@ func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) } s.NodePublicKey = req.NodePublicKey + s.VirtualAddrport, err = manager.generateVirtualAddrPort(pointer.SafeDeref(req.WireguardOverGrpc)) + if err != nil { + return err + } spec = s @@ -668,41 +703,66 @@ func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) endpoint = manager.wgConfig().AdvertisedEndpoint } + // If the virtual address is set, use it as the endpoint to prevent the client from connecting to the actual WG endpoint + if spec.VirtualAddrport != "" { + endpoint = spec.VirtualAddrport + } + return &pb.ProvisionResponse{ - ServerEndpoint: []string{endpoint}, + ServerEndpoint: pb.MakeEndpoints(endpoint), ServerPublicKey: manager.wgConfig().PublicKey, - ServerAddress: manager.wgConfig().ServerAddress, NodeAddressPrefix: spec.NodeSubnet, + ServerAddress: manager.wgConfig().ServerAddress, + GrpcPeerAddrPort: spec.VirtualAddrport, }, nil } +func tunnelStatusChanged(req *pb.ProvisionRequest, link *siderolink.Link) bool { + wgOverGRPC := pointer.SafeDeref(req.WireguardOverGrpc) + virtualAddrPort := link.TypedSpec().Value.VirtualAddrport + + return wgOverGRPC == (virtualAddrPort == "") +} + func (manager *Manager) generateLinkSpec(req *pb.ProvisionRequest) (*specs.SiderolinkSpec, error) { nodePrefix := netip.MustParsePrefix(manager.wgConfig().Subnet) // generated random address for the node - raw := nodePrefix.Addr().As16() - salt := make([]byte, 8) - - _, err := io.ReadFull(rand.Reader, salt) + nodeAddress, err := wireguard.GenerateRandomNodeAddr(nodePrefix) if err != nil { - return nil, err + return nil, fmt.Errorf("error generating random node address: %w", err) } - copy(raw[8:], salt) - - nodeAddress := netip.PrefixFrom(netip.AddrFrom16(raw), nodePrefix.Bits()) - pubKey, err := wgtypes.ParseKey(req.NodePublicKey) if err != nil { return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("error parsing Wireguard key: %s", err)) } + virtualAddrPort, err := manager.generateVirtualAddrPort(pointer.SafeDeref(req.WireguardOverGrpc)) + if err != nil { + return nil, err + } + return &specs.SiderolinkSpec{ - NodeSubnet: nodeAddress.String(), - NodePublicKey: pubKey.String(), + NodeSubnet: nodeAddress.String(), + NodePublicKey: pubKey.String(), + VirtualAddrport: virtualAddrPort, }, nil } +func (manager *Manager) generateVirtualAddrPort(generate bool) (string, error) { + if !generate { + return "", nil + } + + generated, err := wireguard.GenerateRandomNodeAddr(manager.virtualPrefix) + if err != nil { + return "", fmt.Errorf("error generating random virtual node address: %w", err) + } + + return net.JoinHostPort(generated.Addr().String(), "50889"), nil +} + // Describe implements prom.Collector interface. func (manager *Manager) Describe(ch chan<- *prometheus.Desc) { prometheus.DescribeByCollect(manager, ch) @@ -716,3 +776,21 @@ func (manager *Manager) Collect(ch chan<- prometheus.Metric) { } var _ prometheus.Collector = &Manager{} + +type peerHandler struct { + allowedPeers *wggrpc.AllowedPeers +} + +func (p *peerHandler) HandlePeerAdded(event wireguard.PeerEvent) error { + if event.VirtualAddr.IsValid() { + p.allowedPeers.AddToken(event.PubKey, event.VirtualAddr.String()) + } + + return nil +} + +func (p *peerHandler) HandlePeerRemoved(pubKey wgtypes.Key) error { + p.allowedPeers.RemoveToken(pubKey) + + return nil +} diff --git a/internal/pkg/siderolink/siderolink_test.go b/internal/pkg/siderolink/siderolink_test.go index fbc6366f..d1c2249a 100644 --- a/internal/pkg/siderolink/siderolink_test.go +++ b/internal/pkg/siderolink/siderolink_test.go @@ -8,7 +8,6 @@ package siderolink_test import ( "context" "errors" - "net/netip" "sync" "sync/atomic" "testing" @@ -20,8 +19,10 @@ import ( "github.com/cosi-project/runtime/pkg/state" "github.com/cosi-project/runtime/pkg/state/impl/inmem" "github.com/cosi-project/runtime/pkg/state/impl/namespaced" + "github.com/siderolabs/go-pointer" "github.com/siderolabs/go-retry/retry" pb "github.com/siderolabs/siderolink/api/siderolink" + "github.com/siderolabs/siderolink/pkg/wireguard" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" "go.uber.org/zap" @@ -44,7 +45,7 @@ type fakeWireguardHandler struct { loggerMu sync.Mutex } -func (h *fakeWireguardHandler) SetupDevice(netip.Prefix, wgtypes.Key, string, *zap.Logger) error { +func (h *fakeWireguardHandler) SetupDevice(wireguard.DeviceConfig) error { return nil } @@ -237,6 +238,126 @@ func (suite *SiderolinkSuite) TestNodes() { suite.Require().Equal(privateKey.PublicKey().String(), resource.TypedSpec().Value.NodePublicKey) } +func (suite *SiderolinkSuite) TestVirtualNodes() { + var spec *specs.ConnectionParamsSpec + + ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2) + defer cancel() + + rtestutils.AssertResources[*siderolink.Config](ctx, suite.T(), suite.state, []string{ + siderolink.ConfigID, + }, func(r *siderolink.Config, assertion *assert.Assertions) { + assertion.NotEmpty(r.TypedSpec().Value.JoinToken) + assertion.NotEmpty(r.TypedSpec().Value.PrivateKey) + assertion.NotEmpty(r.TypedSpec().Value.PublicKey) + }) + + rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{ + siderolink.ConfigID, + }, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) { + assertion.NotEmpty(r.TypedSpec().Value.Args) + assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint) + assertion.NotEmpty(r.TypedSpec().Value.JoinToken) + assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint) + + spec = r.TypedSpec().Value + }) + + conn, err := grpc.DialContext(suite.ctx, suite.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + suite.Require().NoError(err) + + client := pb.NewProvisionServiceClient(conn) + + privateKey, err := wgtypes.GeneratePrivateKey() + suite.Require().NoError(err) + + resp, err := client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: privateKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + WireguardOverGrpc: pointer.To(true), + }) + + suite.Require().NoError(err) + + suite.Assert().NoError( + retry.Constant(time.Second * 2).Retry(func() error { + list, err := safe.ReaderList[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "", resource.VersionUndefined)) //nolint:govet + if err != nil { + return err + } + + if list.Len() == 0 { + return retry.ExpectedErrorf("no links established yet") + } + + for it := list.Iterator(); it.Next(); { + item := it.Value() + + if item.Metadata().ID() == "" { + return errors.New("empty id in the resource list") + } + + if item.TypedSpec().Value.VirtualAddrport == "" { + return errors.New("empty virtual address in the resource list") + } + } + + return nil + }), + ) + + reprovision, err := client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: privateKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + }) + + expectedResp := resp.CloneVT() + expectedResp.GrpcPeerAddrPort = "" + expectedResp.ServerEndpoint = pb.MakeEndpoints(config.Config.SiderolinkWireguardAdvertisedAddress) + + suite.Assert().NoError(err) + + suite.Require().Equal(expectedResp.String(), reprovision.String()) + + privateKey, err = wgtypes.GeneratePrivateKey() + suite.Assert().NoError(err) + + reprovision, err = client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: privateKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + }) + + suite.Assert().NoError(err) + suite.Require().Equal(expectedResp.String(), reprovision.String()) + + res, err := safe.StateGet[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "testnode", resource.VersionUndefined)) + suite.Assert().NoError(err) + suite.Require().Equal(privateKey.PublicKey().String(), res.TypedSpec().Value.NodePublicKey) + suite.Require().Zero(res.TypedSpec().Value.VirtualAddrport) + + reprovision, err = client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: privateKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + WireguardOverGrpc: pointer.To(true), + }) + + resp.GrpcPeerAddrPort = reprovision.GrpcPeerAddrPort + resp.ServerEndpoint = reprovision.ServerEndpoint + + suite.Assert().NoError(err) + suite.Require().Equal(resp.String(), reprovision.String()) + + res, err = safe.StateGet[*siderolink.Link](suite.ctx, suite.state, resource.NewMetadata(siderolink.Namespace, siderolink.LinkType, "testnode", resource.VersionUndefined)) + suite.Assert().NoError(err) + suite.Require().Equal(privateKey.PublicKey().String(), res.TypedSpec().Value.NodePublicKey) + suite.Require().NotZero(res.TypedSpec().Value.VirtualAddrport) + suite.Require().Equal(reprovision.GrpcPeerAddrPort, res.TypedSpec().Value.VirtualAddrport) +} + func (suite *SiderolinkSuite) TestGenerateJoinToken() { token, err := sideromanager.GenerateJoinToken() diff --git a/internal/pkg/siderolink/wireguard.go b/internal/pkg/siderolink/wireguard.go index bbf6005d..6b7cb418 100644 --- a/internal/pkg/siderolink/wireguard.go +++ b/internal/pkg/siderolink/wireguard.go @@ -10,6 +10,7 @@ import ( "fmt" "net/netip" + "github.com/siderolabs/gen/channel" "github.com/siderolabs/go-pointer" "github.com/siderolabs/siderolink/pkg/wireguard" "go.uber.org/zap" @@ -20,7 +21,7 @@ import ( // WireguardHandler abstraction around peer handler and wgDevice. type WireguardHandler interface { - SetupDevice(netip.Prefix, wgtypes.Key, string, *zap.Logger) error + SetupDevice(wireguard.DeviceConfig) error Shutdown() error Run(context.Context, *zap.Logger) error PeerEvent(context.Context, *specs.SiderolinkSpec, bool) error @@ -48,16 +49,23 @@ func (handler *PhysicalWireguardHandler) PeerEvent(ctx context.Context, spec *sp return err } - select { - case <-ctx.Done(): - case handler.events <- wireguard.PeerEvent{ + var virtualAddrPort netip.AddrPort + + if spec.VirtualAddrport != "" { + virtualAddrPort, err = netip.ParseAddrPort(spec.VirtualAddrport) + if err != nil { + return err + } + } + + channel.SendWithContext(ctx, handler.events, wireguard.PeerEvent{ PubKey: pubKey, - Address: address.Addr(), Remove: deleted, Endpoint: spec.LastEndpoint, + Address: address.Addr(), PersistentKeepAliveInterval: pointer.To(wireguard.RecommendedPersistentKeepAliveInterval), - }: - } + VirtualAddr: virtualAddrPort.Addr(), + }) return nil } @@ -68,13 +76,10 @@ func (handler *PhysicalWireguardHandler) EventCh() <-chan wireguard.PeerEvent { } // SetupDevice implements WireguardHandler. -func (handler *PhysicalWireguardHandler) SetupDevice(serverAddr netip.Prefix, key wgtypes.Key, endpoint string, logger *zap.Logger) error { - wireguardEndpoint, err := netip.ParseAddrPort(endpoint) - if err != nil { - return fmt.Errorf("invalid Wireguard endpoint: %w", err) - } +func (handler *PhysicalWireguardHandler) SetupDevice(cfg wireguard.DeviceConfig) error { + var err error - handler.wgDevice, err = wireguard.NewDevice(serverAddr, key, wireguardEndpoint.Port(), false, logger) + handler.wgDevice, err = wireguard.NewDevice(cfg) if err != nil { return fmt.Errorf("error initializing wgDevice: %w", err) }