From 4563bf2b8580d1a3ef131344e10ab2fcac8bdbd1 Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Mon, 1 Apr 2024 19:54:44 +0300 Subject: [PATCH] fix: remove freshly created `siderolink.Link` if PeerEvent ended with error Current code assumes if `WireguardHandler.PeerEvent` fails on the first connection attempt, it will never try to do `PeerEvent` it again. That is because on the second Talos provision attempt, the check `spec.NodePublicKey != req.NodePublicKey` will return `false`, and `dirty` will be also set to `false`. So Omni will just happily return the `link` and be done with it. That means if (for some reason) `WireguardHandler.PeerEvent` failed on first connection from Talos - it will never configure Wireguard on Omni side until Omni restarts, and you can actually see why `PeerEvent` fails in `pollWireguardPeers`. Fixing this by deleting freshly created `siderolink.Link` if `PeerEvent` failed. Signed-off-by: Dmitriy Matrenichev --- internal/pkg/siderolink/manager.go | 28 ++++++- internal/pkg/siderolink/siderolink_test.go | 85 ++++++++++++++++++---- 2 files changed, 97 insertions(+), 16 deletions(-) diff --git a/internal/pkg/siderolink/manager.go b/internal/pkg/siderolink/manager.go index 39d5aa9b..36a365e9 100644 --- a/internal/pkg/siderolink/manager.go +++ b/internal/pkg/siderolink/manager.go @@ -631,11 +631,13 @@ func (manager *Manager) getLink(ctx context.Context, req *pb.ProvisionRequest, i func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) (*pb.ProvisionResponse, error) { ctx = actor.MarkContextAsInternalActor(ctx) - link, dirty, err := manager.getLink(ctx, req, req.NodeUuid) + link, created, err := manager.getLink(ctx, req, req.NodeUuid) if err != nil { return nil, err } + var updated isDirty + spec := link.TypedSpec().Value if spec.NodePublicKey != req.NodePublicKey { if _, err = safe.StateUpdateWithConflicts(ctx, manager.state, link.Metadata(), func(r *siderolink.Link) error { @@ -654,11 +656,31 @@ func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) return nil, err } - dirty = true + updated = true } - if dirty { + if created || updated { if err = manager.wgHandler.PeerEvent(ctx, spec, false); err != nil { + if !created { + return nil, err + } + + // if the peer event fails and the link was just created, we need to teardown it + teardown, tdErr := manager.state.Teardown(ctx, link.Metadata()) + if tdErr != nil { + return nil, tdErr + } + + if !teardown { + return nil, err + } + + // try to destroy the link immediately` if no finalizer is set + destroyErr := manager.state.Destroy(ctx, link.Metadata()) + if destroyErr != nil { + return nil, destroyErr + } + return nil, err } } diff --git a/internal/pkg/siderolink/siderolink_test.go b/internal/pkg/siderolink/siderolink_test.go index fbc6366f..9a5a70b3 100644 --- a/internal/pkg/siderolink/siderolink_test.go +++ b/internal/pkg/siderolink/siderolink_test.go @@ -10,7 +10,6 @@ import ( "errors" "net/netip" "sync" - "sync/atomic" "testing" "time" @@ -20,6 +19,7 @@ import ( "github.com/cosi-project/runtime/pkg/state" "github.com/cosi-project/runtime/pkg/state/impl/inmem" "github.com/cosi-project/runtime/pkg/state/impl/namespaced" + "github.com/siderolabs/gen/ensure" "github.com/siderolabs/go-retry/retry" pb "github.com/siderolabs/siderolink/api/siderolink" "github.com/stretchr/testify/assert" @@ -39,9 +39,11 @@ import ( sideromanager "github.com/siderolabs/omni/internal/pkg/siderolink" ) +//nolint:govet type fakeWireguardHandler struct { - logger *zap.Logger - loggerMu sync.Mutex + logger *zap.Logger + loggerMu sync.Mutex + peerEventFailFor map[string]struct{} // public key for which to fail the peer event } func (h *fakeWireguardHandler) SetupDevice(netip.Prefix, wgtypes.Key, string, *zap.Logger) error { @@ -68,6 +70,10 @@ func (h *fakeWireguardHandler) PeerEvent(_ context.Context, spec *specs.Sideroli h.loggerMu.Lock() defer h.loggerMu.Unlock() + if _, ok := h.peerEventFailFor[spec.NodePublicKey]; ok { + return errors.New("peer event failed") + } + msg := "updated peer" if deleted { msg = "removed peer" @@ -108,7 +114,19 @@ func (suite *SiderolinkSuite) SetupTest() { var err error - suite.manager, err = sideromanager.NewManager(suite.ctx, suite.state, &fakeWireguardHandler{}, params, zaptest.NewLogger(suite.T()), nil, nil) + suite.manager, err = sideromanager.NewManager( + suite.ctx, + suite.state, + &fakeWireguardHandler{ + peerEventFailFor: map[string]struct{}{ + peerEvenFailWGKey.PublicKey().String(): {}, + }, + }, + params, + zaptest.NewLogger(suite.T()), + nil, + nil, + ) suite.Require().NoError(err) suite.startManager(params) @@ -237,6 +255,55 @@ func (suite *SiderolinkSuite) TestNodes() { suite.Require().Equal(privateKey.PublicKey().String(), resource.TypedSpec().Value.NodePublicKey) } +var peerEvenFailWGKey = ensure.Value(wgtypes.GeneratePrivateKey()) + +func (suite *SiderolinkSuite) TestPeerEventShouldFail() { + var spec *specs.ConnectionParamsSpec + + ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2) + defer cancel() + + rtestutils.AssertResources[*siderolink.Config](ctx, suite.T(), suite.state, []string{ + siderolink.ConfigID, + }, func(r *siderolink.Config, assertion *assert.Assertions) { + assertion.NotEmpty(r.TypedSpec().Value.JoinToken) + assertion.NotEmpty(r.TypedSpec().Value.PrivateKey) + assertion.NotEmpty(r.TypedSpec().Value.PublicKey) + }) + + rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{ + siderolink.ConfigID, + }, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) { + assertion.NotEmpty(r.TypedSpec().Value.Args) + assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint) + assertion.NotEmpty(r.TypedSpec().Value.JoinToken) + assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint) + + spec = r.TypedSpec().Value + }) + + conn, err := grpc.DialContext(suite.ctx, suite.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + suite.Require().NoError(err) + + client := pb.NewProvisionServiceClient(conn) + + _, err = client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: peerEvenFailWGKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + }) + + suite.Require().Error(err) + + _, err = client.Provision(suite.ctx, &pb.ProvisionRequest{ + NodeUuid: "testnode", + NodePublicKey: peerEvenFailWGKey.PublicKey().String(), + JoinToken: &spec.JoinToken, + }) + + suite.Assert().Error(err) +} + func (suite *SiderolinkSuite) TestGenerateJoinToken() { token, err := sideromanager.GenerateJoinToken() @@ -262,13 +329,5 @@ func TestSiderolinkSuite(t *testing.T) { func safeLock(mx sync.Locker) func() { mx.Lock() - var locked atomic.Bool - - locked.Store(true) - - return func() { - if locked.Swap(false) { - mx.Unlock() - } - } + return sync.OnceFunc(mx.Unlock) }