Skip to content

Commit

Permalink
fix: remove freshly created siderolink.Link if PeerEvent ended with…
Browse files Browse the repository at this point in the history
… error

With the current code, if `WireguardHandler.PeerEvent` fails on the first connection attempt, `PeerEvent` is never attempted again. That is because on the second Talos provision attempt the check `spec.NodePublicKey != req.NodePublicKey` evaluates to `false`, so `dirty` is also `false`, and Omni simply returns the existing `link` without calling `PeerEvent`. This means that if (for some reason) `WireguardHandler.PeerEvent` fails on the first connection from Talos, Wireguard is never configured on the Omni side until Omni restarts; you can see why `PeerEvent` fails in `pollWireguardPeers`.

Fix this by deleting the freshly created `siderolink.Link` if `PeerEvent` fails.

Signed-off-by: Dmitriy Matrenichev <[email protected]>
  • Loading branch information
DmitriyMV committed Apr 1, 2024
1 parent ae85293 commit 4563bf2
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 16 deletions.
28 changes: 25 additions & 3 deletions internal/pkg/siderolink/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,11 +631,13 @@ func (manager *Manager) getLink(ctx context.Context, req *pb.ProvisionRequest, i
func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest) (*pb.ProvisionResponse, error) {
ctx = actor.MarkContextAsInternalActor(ctx)

link, dirty, err := manager.getLink(ctx, req, req.NodeUuid)
link, created, err := manager.getLink(ctx, req, req.NodeUuid)
if err != nil {
return nil, err
}

var updated isDirty

spec := link.TypedSpec().Value
if spec.NodePublicKey != req.NodePublicKey {
if _, err = safe.StateUpdateWithConflicts(ctx, manager.state, link.Metadata(), func(r *siderolink.Link) error {
Expand All @@ -654,11 +656,31 @@ func (manager *Manager) Provision(ctx context.Context, req *pb.ProvisionRequest)
return nil, err
}

dirty = true
updated = true
}

if dirty {
if created || updated {
if err = manager.wgHandler.PeerEvent(ctx, spec, false); err != nil {
if !created {
return nil, err
}

// if the peer event fails and the link was just created, we need to tear it down
teardown, tdErr := manager.state.Teardown(ctx, link.Metadata())
if tdErr != nil {
return nil, tdErr
}

if !teardown {
return nil, err
}

// try to destroy the link immediately if no finalizer is set
destroyErr := manager.state.Destroy(ctx, link.Metadata())
if destroyErr != nil {
return nil, destroyErr
}

return nil, err
}
}
Expand Down
85 changes: 72 additions & 13 deletions internal/pkg/siderolink/siderolink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"errors"
"net/netip"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -20,6 +19,7 @@ import (
"github.com/cosi-project/runtime/pkg/state"
"github.com/cosi-project/runtime/pkg/state/impl/inmem"
"github.com/cosi-project/runtime/pkg/state/impl/namespaced"
"github.com/siderolabs/gen/ensure"
"github.com/siderolabs/go-retry/retry"
pb "github.com/siderolabs/siderolink/api/siderolink"
"github.com/stretchr/testify/assert"
Expand All @@ -39,9 +39,11 @@ import (
sideromanager "github.com/siderolabs/omni/internal/pkg/siderolink"
)

// fakeWireguardHandler is the test double passed to sideromanager.NewManager
// in place of a real WireGuard handler.
//
// loggerMu is held by PeerEvent for the duration of each call.
//
//nolint:govet
type fakeWireguardHandler struct {
	logger           *zap.Logger
	loggerMu         sync.Mutex
	peerEventFailFor map[string]struct{} // public key for which to fail the peer event
}

func (h *fakeWireguardHandler) SetupDevice(netip.Prefix, wgtypes.Key, string, *zap.Logger) error {
Expand All @@ -68,6 +70,10 @@ func (h *fakeWireguardHandler) PeerEvent(_ context.Context, spec *specs.Sideroli
h.loggerMu.Lock()
defer h.loggerMu.Unlock()

if _, ok := h.peerEventFailFor[spec.NodePublicKey]; ok {
return errors.New("peer event failed")
}

msg := "updated peer"
if deleted {
msg = "removed peer"
Expand Down Expand Up @@ -108,7 +114,19 @@ func (suite *SiderolinkSuite) SetupTest() {

var err error

suite.manager, err = sideromanager.NewManager(suite.ctx, suite.state, &fakeWireguardHandler{}, params, zaptest.NewLogger(suite.T()), nil, nil)
suite.manager, err = sideromanager.NewManager(
suite.ctx,
suite.state,
&fakeWireguardHandler{
peerEventFailFor: map[string]struct{}{
peerEvenFailWGKey.PublicKey().String(): {},
},
},
params,
zaptest.NewLogger(suite.T()),
nil,
nil,
)
suite.Require().NoError(err)

suite.startManager(params)
Expand Down Expand Up @@ -237,6 +255,55 @@ func (suite *SiderolinkSuite) TestNodes() {
suite.Require().Equal(privateKey.PublicKey().String(), resource.TypedSpec().Value.NodePublicKey)
}

// peerEvenFailWGKey is a WireGuard private key generated once at package
// initialization. SetupTest registers its public key in
// fakeWireguardHandler.peerEventFailFor, so the fake PeerEvent returns an
// error for any link spec carrying that public key.
var peerEvenFailWGKey = ensure.Value(wgtypes.GeneratePrivateKey())

// TestPeerEventShouldFail provisions the same node twice with a public key
// for which the fake WireGuard handler always fails PeerEvent, and expects
// both Provision calls to return an error.
//
// The second failure is the regression check: it only happens if the link
// created by the first attempt was removed, so that the manager goes through
// PeerEvent again instead of returning the stored link.
func (suite *SiderolinkSuite) TestPeerEventShouldFail() {
	var spec *specs.ConnectionParamsSpec

	ctx, cancel := context.WithTimeout(suite.ctx, time.Second*2)
	defer cancel()

	// Wait until the manager has populated its config resource.
	rtestutils.AssertResources[*siderolink.Config](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.Config, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.PrivateKey)
		assertion.NotEmpty(r.TypedSpec().Value.PublicKey)
	})

	// Capture the connection params; the join token is needed for Provision.
	rtestutils.AssertResources[*siderolink.ConnectionParams](ctx, suite.T(), suite.state, []string{
		siderolink.ConfigID,
	}, func(r *siderolink.ConnectionParams, assertion *assert.Assertions) {
		assertion.NotEmpty(r.TypedSpec().Value.Args)
		assertion.NotEmpty(r.TypedSpec().Value.ApiEndpoint)
		assertion.NotEmpty(r.TypedSpec().Value.JoinToken)
		assertion.NotEmpty(r.TypedSpec().Value.WireguardEndpoint)

		spec = r.TypedSpec().Value
	})

	conn, err := grpc.DialContext(suite.ctx, suite.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
	suite.Require().NoError(err)

	// Release the client connection when the test finishes so it does not leak.
	defer conn.Close() //nolint:errcheck

	client := pb.NewProvisionServiceClient(conn)

	// First attempt: the link is created, PeerEvent fails, the request errors out.
	_, err = client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: peerEvenFailWGKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})

	suite.Require().Error(err)

	// Second attempt with identical parameters must fail as well.
	_, err = client.Provision(suite.ctx, &pb.ProvisionRequest{
		NodeUuid:      "testnode",
		NodePublicKey: peerEvenFailWGKey.PublicKey().String(),
		JoinToken:     &spec.JoinToken,
	})

	suite.Assert().Error(err)
}

func (suite *SiderolinkSuite) TestGenerateJoinToken() {
token, err := sideromanager.GenerateJoinToken()

Expand All @@ -262,13 +329,5 @@ func TestSiderolinkSuite(t *testing.T) {
// safeLock locks mx and returns a function that unlocks it.
//
// The returned function is safe to call more than once: only the first call
// invokes mx.Unlock, every later call is a no-op.
func safeLock(mx sync.Locker) func() {
	mx.Lock()

	return sync.OnceFunc(mx.Unlock)
}

0 comments on commit 4563bf2

Please sign in to comment.