diff --git a/derp/derp_server.go b/derp/derp_server.go index c033e42e7856e..b7cdf252471d8 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -155,7 +155,7 @@ type Server struct { mu sync.Mutex closed bool netConns map[Conn]chan struct{} // chan is closed when conn closes - clients map[key.NodePublic]clientSet + clients map[key.NodePublic]*clientSet watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. If the value is nil, that means the @@ -177,8 +177,6 @@ type Server struct { // clientSet represents 1 or more *sclients. // -// The two implementations are singleClient and *dupClientSet. -// // In the common case, client should only have one connection to the // DERP server for a given key. When they're connected multiple times, // we record their set of connections in dupClientSet and keep their @@ -194,26 +192,49 @@ type Server struct { // "health_error" frame to them that'll communicate to the end users // that they cloned a device key, and we'll also surface it in the // admin panel, etc. -type clientSet interface { - // ActiveClient returns the most recently added client to - // the set, as long as it hasn't been disabled, in which - // case it returns nil. - ActiveClient() *sclient - - // Len returns the number of clients in the set. - Len() int - - // ForeachClient calls f for each client in the set. - ForeachClient(f func(*sclient)) +type clientSet struct { + // activeClient holds the currently active connection for the set. It's nil + // if there are no connections or the connection is disabled. + // + // A pointer to a clientSet can be held by peers for long periods of time + // without holding Server.mu to avoid mutex contention on Server.mu, only + // re-acquiring the mutex and checking the clients map if activeClient is + // nil. + activeClient atomic.Pointer[sclient] + + // dup is non-nil if there are multiple connections for the + // public key. It's nil in the common case of only one + // client being connected. + // + // dup is guarded by Server.mu. + dup *dupClientSet } -// singleClient is a clientSet of a single connection. -// This is the common case. -type singleClient struct{ c *sclient } +// Len returns the number of clients in s, which can be +// 0, 1 (the common case), or more (for buggy or transiently +// reconnecting clients). +func (s *clientSet) Len() int { + if s.dup != nil { + return len(s.dup.set) + } + if s.activeClient.Load() != nil { + return 1 + } + return 0 +} -func (s singleClient) ActiveClient() *sclient { return s.c } -func (s singleClient) Len() int { return 1 } -func (s singleClient) ForeachClient(f func(*sclient)) { f(s.c) } +// ForeachClient calls f for each client in the set. +// +// The Server.mu must be held. +func (s *clientSet) ForeachClient(f func(*sclient)) { + if s.dup != nil { + for c := range s.dup.set { + f(c) + } + } else if c := s.activeClient.Load(); c != nil { + f(c) + } +} // A dupClientSet is a clientSet of more than 1 connection. // @@ -224,11 +245,12 @@ func (s singleClient) ForeachClient(f func(*sclient)) { f(s.c) } // // All fields are guarded by Server.mu. type dupClientSet struct { - // set is the set of connected clients for sclient.key. + // set is the set of connected clients for sclient.key, + // including the clientSet's active one. set set.Set[*sclient] // last is the most recent addition to set, or nil if the most - // recent one has since disconnected and nobody else has send + // recent one has since disconnected and nobody else has sent // data since. last *sclient @@ -239,17 +261,15 @@ type dupClientSet struct { sendHistory []*sclient } -func (s *dupClientSet) ActiveClient() *sclient { - if s.last != nil && !s.last.isDisabled.Load() { - return s.last +func (s *clientSet) pickActiveClient() *sclient { + d := s.dup + if d == nil { + return s.activeClient.Load() } - return nil -} -func (s *dupClientSet) Len() int { return len(s.set) } -func (s *dupClientSet) ForeachClient(f func(*sclient)) { - for c := range s.set { - f(c) + if d.last != nil && !d.last.isDisabled.Load() { + return d.last } + return nil } // removeClient removes c from s and reports whether it was in s @@ -317,7 +337,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { packetsRecvByKind: metrics.LabelMap{Label: "kind"}, packetsDroppedReason: metrics.LabelMap{Label: "reason"}, packetsDroppedType: metrics.LabelMap{Label: "type"}, - clients: map[key.NodePublic]clientSet{}, + clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, netConns: map[Conn]chan struct{}{}, memSys0: ms.Sys, @@ -444,7 +464,7 @@ func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool { if !ok { return false } - return x.ActiveClient() != nil + return x.activeClient.Load() != nil } // Accept adds a new connection to the server and serves it. @@ -534,37 +554,43 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - curSet := s.clients[c.key] - switch curSet := curSet.(type) { - case nil: - s.clients[c.key] = singleClient{c} + cs, ok := s.clients[c.key] + if !ok { c.debugLogf("register single client") - case singleClient: + cs = &clientSet{} + s.clients[c.key] = cs + } + was := cs.activeClient.Load() + if was == nil { + // Common case. + } else { + was.isDup.Store(true) + c.isDup.Store(true) + } + + dup := cs.dup + if dup == nil && was != nil { s.dupClientKeys.Add(1) s.dupClientConns.Add(2) // both old and new count s.dupClientConnTotal.Add(1) - old := curSet.ActiveClient() - old.isDup.Store(true) - c.isDup.Store(true) - s.clients[c.key] = &dupClientSet{ - last: c, - set: set.Set[*sclient]{ - old: struct{}{}, - c: struct{}{}, - }, - sendHistory: []*sclient{old}, + dup = &dupClientSet{ + set: set.Of(c, was), + last: c, + sendHistory: []*sclient{was}, } + cs.dup = dup c.debugLogf("register duplicate client") - case *dupClientSet: + } else if dup != nil { s.dupClientConns.Add(1) // the gauge s.dupClientConnTotal.Add(1) // the counter - c.isDup.Store(true) - curSet.set.Add(c) - curSet.last = c - curSet.sendHistory = append(curSet.sendHistory, c) + dup.set.Add(c) + dup.last = c + dup.sendHistory = append(dup.sendHistory, c) c.debugLogf("register another duplicate client") } + cs.activeClient.Store(c) + if _, ok := s.clientsMesh[c.key]; !ok { s.clientsMesh[c.key] = nil // just for varz of total users in cluster } @@ -595,30 +621,47 @@ func (s *Server) unregisterClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - set := s.clients[c.key] - switch set := set.(type) { - case nil: + set, ok := s.clients[c.key] + if !ok { c.logf("[unexpected]; clients map is empty") - case singleClient: + return + } + + dup := set.dup + if dup == nil { + // The common case. + cur := set.activeClient.Load() + if cur == nil { + c.logf("[unexpected]; active client is nil") + return + } + if cur != c { + c.logf("[unexpected]; active client is not c") + return + } c.debugLogf("removed connection") + set.activeClient.Store(nil) delete(s.clients, c.key) if v, ok := s.clientsMesh[c.key]; ok && v == nil { delete(s.clientsMesh, c.key) s.notePeerGoneFromRegionLocked(c.key) } s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, 0, false) - case *dupClientSet: + } else { c.debugLogf("removed duplicate client") - if set.removeClient(c) { + if dup.removeClient(c) { s.dupClientConns.Add(-1) } else { c.logf("[unexpected]; dup client set didn't shrink") } - if set.Len() == 1 { + if dup.set.Len() == 1 { + // If we drop down to one connection, demote it down + // to a regular single client (a nil dup set). + set.dup = nil s.dupClientConns.Add(-1) // again; for the original one's s.dupClientKeys.Add(-1) var remain *sclient - for remain = range set.set { + for remain = range dup.set { break } if remain == nil { @@ -626,7 +669,10 @@ func (s *Server) unregisterClient(c *sclient) { } remain.isDisabled.Store(false) remain.isDup.Store(false) - s.clients[c.key] = singleClient{remain} + set.activeClient.Store(remain) + } else { + // Still a duplicate. Pick a winner. + set.activeClient.Store(set.pickActiveClient()) } } @@ -697,7 +743,7 @@ func (s *Server) addWatcher(c *sclient) { // Queue messages for each already-connected client. for peer, clientSet := range s.clients { - ac := clientSet.ActiveClient() + ac := clientSet.activeClient.Load() if ac == nil { continue } @@ -955,7 +1001,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { s.mu.Lock() if set, ok := s.clients[dstKey]; ok { dstLen = set.Len() - dst = set.ActiveClient() + dst = set.activeClient.Load() } if dst != nil { s.notePeerSendLocked(srcKey, dst) @@ -1010,7 +1056,7 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s.mu.Lock() if set, ok := s.clients[dstKey]; ok { dstLen = set.Len() - dst = set.ActiveClient() + dst = set.activeClient.Load() } if dst != nil { s.notePeerSendLocked(c.key, dst) @@ -1256,22 +1302,28 @@ func (s *Server) noteClientActivity(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - ds, ok := s.clients[c.key].(*dupClientSet) + cs, ok := s.clients[c.key] if !ok { + return + } + dup := cs.dup + if dup == nil { // It became unduped in between the isDup fast path check above // and the mutex check. Nothing to do. return } if s.dupPolicy == lastWriterIsActive { - ds.last = c - } else if ds.last == nil { + dup.last = c + cs.activeClient.Store(c) + } else if dup.last == nil { // If we didn't have a primary, let the current // speaker be the primary. - ds.last = c + dup.last = c + cs.activeClient.Store(c) } - if sh := ds.sendHistory; len(sh) != 0 && sh[len(sh)-1] == c { + if sh := dup.sendHistory; len(sh) != 0 && sh[len(sh)-1] == c { // The client c was the last client to make activity // in this set and it was already recorded. Nothing to // do. @@ -1281,10 +1333,13 @@ func (s *Server) noteClientActivity(c *sclient) { // If we saw this connection send previously, then consider // the group fighting and disable them all. if s.dupPolicy == disableFighters { - for _, prior := range ds.sendHistory { + for _, prior := range dup.sendHistory { if prior == c { - ds.ForeachClient(func(c *sclient) { + cs.ForeachClient(func(c *sclient) { c.isDisabled.Store(true) + if cs.activeClient.Load() == c { + cs.activeClient.Store(nil) + } }) break } @@ -1292,7 +1347,7 @@ func (s *Server) noteClientActivity(c *sclient) { } // Append this client to the list of clients who spoke last. - ds.sendHistory = append(ds.sendHistory, c) + dup.sendHistory = append(dup.sendHistory, c) } type serverInfo struct { @@ -1407,6 +1462,11 @@ func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, d // sclient is a client connection to the server. // +// A node (a wireguard public key) can be connected multiple times to a DERP server +// and thus have multiple sclient instances. An sclient represents +// only one of these possibly multiple connections. See clientSet for the +// type that represents the set of all connections for a given key. +// // (The "s" prefix is to more explicitly distinguish it from Client in derp_client.go) type sclient struct { // Static after construction. diff --git a/derp/derp_test.go b/derp/derp_test.go index dde2054e65fcd..72de265529ad1 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -731,7 +731,7 @@ func pubAll(b byte) (ret key.NodePublic) { func TestForwarderRegistration(t *testing.T) { s := &Server{ - clients: make(map[key.NodePublic]clientSet), + clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } want := func(want map[key.NodePublic]PacketForwarder) { @@ -746,6 +746,11 @@ func TestForwarderRegistration(t *testing.T) { t.Errorf("counter = %v; want %v", got, want) } } + singleClient := func(c *sclient) *clientSet { + cs := &clientSet{} + cs.activeClient.Store(c) + return cs + } u1 := pubAll(1) u2 := pubAll(2) @@ -808,7 +813,7 @@ func TestForwarderRegistration(t *testing.T) { key: u1, logf: logger.Discard, } - s.clients[u1] = singleClient{u1c} + s.clients[u1] = singleClient(u1c) s.RemovePacketForwarder(u1, testFwd(100)) want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -828,7 +833,7 @@ func TestForwarderRegistration(t *testing.T) { // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard // that they're also connected to a peer of ours. That shouldn't transition the forwarder // from nil to the new one, not a multiForwarder. - s.clients[u1] = singleClient{u1c} + s.clients[u1] = singleClient(u1c) s.clientsMesh[u1] = nil want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -860,7 +865,7 @@ func TestMultiForwarder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ - clients: make(map[key.NodePublic]clientSet), + clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } u := pubAll(1) @@ -1078,43 +1083,48 @@ func TestServerDupClients(t *testing.T) { } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() - switch s := s.clients[want.key].(type) { - case singleClient: - if s.c != want { - t.Error("wrong single client") - return - } - if want.isDup.Load() { + got, ok := s.clients[want.key] + if !ok { + t.Error("no clients for key") + return + } + if got.dup != nil { + t.Errorf("unexpected dup set for single client") + } + cur := got.activeClient.Load() + if cur != want { + t.Errorf("active client = %q; want %q", clientName[cur], clientName[want]) + } + if cur != nil { + if cur.isDup.Load() { t.Errorf("unexpected isDup on singleClient") } - if want.isDisabled.Load() { + if cur.isDisabled.Load() { t.Errorf("unexpected isDisabled on singleClient") } - case nil: - t.Error("no clients for key") - case *dupClientSet: - t.Error("unexpected multiple clients for key") } } wantNoClient := func(t *testing.T) { t.Helper() - switch s := s.clients[clientPub].(type) { - case nil: - // Good. + _, ok := s.clients[clientPub] + if !ok { + // Good return - default: - t.Errorf("got %T; want empty", s) } + t.Errorf("got client; want empty") } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() - switch s := s.clients[clientPub].(type) { - case *dupClientSet: - return s - default: - t.Fatalf("wanted dup set; got %T", s) + cs, ok := s.clients[clientPub] + if !ok { + t.Fatal("no set for key; want dup set") return nil } + if cs.dup != nil { + return cs.dup + } + t.Fatalf("no dup set for key; want dup set") + return nil } wantActive := func(t *testing.T, want *sclient) { t.Helper() @@ -1123,7 +1133,7 @@ func TestServerDupClients(t *testing.T) { t.Error("no set for key") return } - got := set.ActiveClient() + got := set.activeClient.Load() if got != want { t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) }