From a9824e78d01c515a5f7144cc644e8c4dad0f50c9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 31 Dec 2024 21:32:21 +0000 Subject: [PATCH] Refactor some JetStream helper code, add support for specifying JetStream domain Signed-off-by: Neil Alexander --- setup/config/config_jetstream.go | 2 + setup/jetstream/helpers.go | 141 ++++++++++++++++--------------- setup/jetstream/nats.go | 91 +++++++++++--------- 3 files changed, 127 insertions(+), 107 deletions(-) diff --git a/setup/config/config_jetstream.go b/setup/config/config_jetstream.go index a048e4d09..c37f917cc 100644 --- a/setup/config/config_jetstream.go +++ b/setup/config/config_jetstream.go @@ -15,6 +15,8 @@ type JetStream struct { // The prefix to use for stream names for this homeserver - really only // useful if running more than one Dendrite on the same NATS deployment. TopicPrefix string `yaml:"topic_prefix"` + // The JetStream domain, if needed. + JetStreamDomain string `yaml:"js_domain"` // Keep all storage in memory. This is mostly useful for unit tests. InMemory bool `yaml:"in_memory"` // Disable logging. This is mostly useful for unit tests. diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 533652160..672f3e6ac 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -22,7 +22,7 @@ func JetStreamConsumer( f func(ctx context.Context, msgs []*nats.Msg) bool, opts ...nats.SubOpt, ) error { - defer func() { + defer func(durable string) { // If there are existing consumers from before they were pull // consumers, we need to clean up the old push consumers. However, // in order to not affect the interest-based policies, we need to @@ -33,86 +33,93 @@ func JetStreamConsumer( logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable) } } - }() + }(durable) - name := durable + "Pull" - sub, err := js.PullSubscribe(subj, name, opts...) + durable = durable + "Pull" + sub, err := js.PullSubscribe(subj, durable, opts...) if err != nil { sentry.CaptureException(err) - return fmt.Errorf("nats.SubscribeSync: %w", err) + logrus.WithContext(ctx).WithError(err).Warnf("Failed to configure durable %q", durable) + return err } - go func() { - for { - // If the parent context has given up then there's no point in - // carrying on doing anything, so stop the listener. - select { - case <-ctx.Done(): - if err := sub.Unsubscribe(); err != nil { - logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable) - } - return - default: - } - // The context behaviour here is surprising — we supply a context - // so that we can interrupt the fetch if we want, but NATS will still - // enforce its own deadline (roughly 5 seconds by default). Therefore - // it is our responsibility to check whether our context expired or - // not when a context error is returned. Footguns. Footguns everywhere. - msgs, err := sub.Fetch(batch, nats.Context(ctx)) - if err != nil { - if err == context.Canceled || err == context.DeadlineExceeded { - // Work out whether it was the JetStream context that expired - // or whether it was our supplied context. - select { - case <-ctx.Done(): - // The supplied context expired, so we want to stop the - // consumer altogether. - return - default: - // The JetStream context expired, so the fetch probably - // just timed out and we should try again. - continue - } - } else if errors.Is(err, nats.ErrConsumerDeleted) { - // The consumer was deleted so stop. + go jetStreamConsumerWorker(ctx, sub, subj, batch, f) + return nil +} + +func jetStreamConsumerWorker( + ctx context.Context, sub *nats.Subscription, subj string, batch int, + f func(ctx context.Context, msgs []*nats.Msg) bool, +) { + for { + // If the parent context has given up then there's no point in + // carrying on doing anything, so stop the listener. + select { + case <-ctx.Done(): + return + default: + } + // The context behaviour here is surprising — we supply a context + // so that we can interrupt the fetch if we want, but NATS will still + // enforce its own deadline (roughly 5 seconds by default). Therefore + // it is our responsibility to check whether our context expired or + // not when a context error is returned. Footguns. Footguns everywhere. + msgs, err := sub.Fetch(batch, nats.Context(ctx)) + if err != nil { + if err == context.Canceled || err == context.DeadlineExceeded { + // Work out whether it was the JetStream context that expired + // or whether it was our supplied context. + select { + case <-ctx.Done(): + // The supplied context expired, so we want to stop the + // consumer altogether. return - } else { - // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string - if err.Error() == "nats: Server Shutdown" { - logrus.WithContext(ctx).Warn("nats server shutting down") - return - } - // Something else went wrong, so we'll panic. - sentry.CaptureException(err) - logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) + default: + // The JetStream context expired, so the fetch probably + // just timed out and we should try again. + continue } + } else if errors.Is(err, nats.ErrTimeout) { + // Pull request was invalidated, try again. + continue + } else if errors.Is(err, nats.ErrConsumerLeadershipChanged) { + // Leadership changed so pending pull requests became invalidated, + // just try again. + continue + } else if err.Error() == "nats: Server Shutdown" { + // The server is shutting down, but we'll rely on reconnect + // behaviour to try and either connect us to another node (if + // clustered) or to reconnect when the server comes back up. + continue + } else { + // Something else went wrong. + logrus.WithContext(ctx).WithField("subject", subj).WithError(err).Warn("Error on pull subscriber fetch") + return } - if len(msgs) < 1 { + } + if len(msgs) < 1 { + continue + } + for _, msg := range msgs { + if err = msg.InProgress(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + sentry.CaptureException(err) continue } + } + if f(ctx, msgs) { for _, msg := range msgs { - if err = msg.InProgress(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err)) + if err = msg.AckSync(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) sentry.CaptureException(err) - continue } } - if f(ctx, msgs) { - for _, msg := range msgs { - if err = msg.AckSync(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err)) - sentry.CaptureException(err) - } - } - } else { - for _, msg := range msgs { - if err = msg.Nak(nats.Context(ctx)); err != nil { - logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) - sentry.CaptureException(err) - } + } else { + for _, msg := range msgs { + if err = msg.Nak(nats.Context(ctx)); err != nil { + logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err)) + sentry.CaptureException(err) } } } - }() - return nil + } } diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index cc896f8ee..12dc97a99 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/getsentry/sentry-go" "github.com/sirupsen/logrus" "github.com/element-hq/dendrite/setup/config" @@ -36,17 +35,20 @@ func DeleteAllStreams(js natsclient.JetStreamContext, cfg *config.JetStream) { func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) { natsLock.Lock() defer natsLock.Unlock() - // check if we need an in-process NATS Server - if len(cfg.Addresses) != 0 { - // reuse existing connections - if s.nc != nil { - return s.js, s.nc - } + var err error + + // If an existing connection exists, return it. + if s.nc != nil && s.js != nil { + return s.js, s.nc + } + + // For connecting to an external NATS server. + if len(cfg.Addresses) > 0 { s.js, s.nc = setupNATS(process, cfg, nil) return s.js, s.nc } - if s.Server == nil { - var err error + + if len(cfg.Addresses) == 0 && s.Server == nil { opts := &natsserver.Options{ ServerName: "monolith", DontListen: true, @@ -58,8 +60,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS NoLog: cfg.NoLog, SyncAlways: true, } - s.Server, err = natsserver.NewServer(opts) - if err != nil { + if s.Server, err = natsserver.NewServer(opts); err != nil { panic(err) } if !cfg.NoLog { @@ -75,37 +76,47 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS s.WaitForShutdown() process.ComponentFinished() }() + if !s.ReadyForConnections(time.Second * 60) { + logrus.Fatalln("NATS did not start in time") + } } - if !s.ReadyForConnections(time.Second * 60) { - logrus.Fatalln("NATS did not start in time") - } - // reuse existing connections - if s.nc != nil { - return s.js, s.nc - } - nc, err := natsclient.Connect("", natsclient.InProcessServer(s)) - if err != nil { + + // No existing process connection, create a new one. + if s.nc, err = natsclient.Connect("", natsclient.InProcessServer(s.Server)); err != nil { logrus.Fatalln("Failed to create NATS client") } - js, _ := setupNATS(process, cfg, nc) - s.js = js - s.nc = nc - return js, nc + s.js, s.nc = setupNATS(process, cfg, s.nc) + return s.js, s.nc } // nolint:gocyclo func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStreamContext, *natsclient.Conn) { + jsOpts := []natsclient.JSOpt{} + if cfg.JetStreamDomain != "" { + jsOpts = append(jsOpts, natsclient.Domain(cfg.JetStreamDomain)) + } + if nc == nil { var err error - opts := []natsclient.Option{} + opts := []natsclient.Option{ + natsclient.Name("Dendrite"), + natsclient.MaxReconnects(-1), // Try forever + natsclient.ReconnectJitter(time.Second, time.Second), + natsclient.ReconnectWait(time.Second * 10), + natsclient.ReconnectHandler(func(c *natsclient.Conn) { + js, jerr := c.JetStream(jsOpts...) + if jerr != nil { + logrus.WithError(jerr).Panic("Unable to get JetStream context in reconnect handler") + return + } + checkAndConfigureStreams(process, cfg, js) + }), + } if cfg.DisableTLSValidation { opts = append(opts, natsclient.Secure(&tls.Config{ InsecureSkipVerify: true, })) } - if string(cfg.Credentials) != "" { - opts = append(opts, natsclient.UserCredentials(string(cfg.Credentials))) - } nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","), opts...) if err != nil { logrus.WithError(err).Panic("Unable to connect to NATS") @@ -113,15 +124,19 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc } } - s, err := nc.JetStream() + js, err := nc.JetStream(jsOpts...) if err != nil { logrus.WithError(err).Panic("Unable to get JetStream context") return nil, nil } + checkAndConfigureStreams(process, cfg, js) + return js, nc +} +func checkAndConfigureStreams(process *process.ProcessContext, cfg *config.JetStream, js natsclient.JetStreamContext) { for _, stream := range streams { // streams are defined in streams.go name := cfg.Prefixed(stream.Name) - info, err := s.StreamInfo(name) + info, err := js.StreamInfo(name) if err != nil && err != natsclient.ErrStreamNotFound { logrus.WithError(err).Fatal("Unable to get stream info") } @@ -153,11 +168,11 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc case info.Config.MaxAge != stream.MaxAge: // Try updating the stream first, as many things can be updated // non-destructively. - if info, err = s.UpdateStream(stream); err != nil { + if info, err = js.UpdateStream(stream); err != nil { logrus.WithError(err).Warnf("Unable to update stream %q, recreating...", name) // We failed to update the stream, this is a last attempt to get // things working but may result in data loss. - if err = s.DeleteStream(name); err != nil { + if err = js.DeleteStream(name); err != nil { logrus.WithError(err).Fatalf("Unable to delete stream %q", name) } info = nil @@ -176,7 +191,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc namespaced := *stream namespaced.Name = name namespaced.Subjects = subjects - if _, err = s.AddStream(&namespaced); err != nil { + if _, err = js.AddStream(&namespaced); err != nil { logger := logrus.WithError(err).WithFields(logrus.Fields{ "stream": namespaced.Name, "subjects": namespaced.Subjects, @@ -193,10 +208,9 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc // we can't recover anything that was queued on the disk but we // will still be able to start and run hopefully in the meantime. logger.WithError(err).Error("Unable to add stream") - sentry.CaptureException(fmt.Errorf("Unable to add stream %q: %w", namespaced.Name, err)) namespaced.Storage = natsclient.MemoryStorage - if _, err = s.AddStream(&namespaced); err != nil { + if _, err = js.AddStream(&namespaced); err != nil { // We tried to add the stream in-memory instead but something // went wrong. That's an unrecoverable situation so we will // give up at this point. @@ -208,7 +222,6 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc // disk will be left alone, but our ability to recover from a // future crash will be limited. Yell about it. err := fmt.Errorf("Stream %q is running in-memory; this may be due to data corruption in the JetStream storage directory", namespaced.Name) - sentry.CaptureException(err) process.Degraded(err) } } @@ -229,15 +242,13 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc streamName := cfg.Matrix.JetStream.Prefixed(stream) for _, consumer := range consumers { consumerName := cfg.Matrix.JetStream.Prefixed(consumer) + "Pull" - consumerInfo, err := s.ConsumerInfo(streamName, consumerName) + consumerInfo, err := js.ConsumerInfo(streamName, consumerName) if err != nil || consumerInfo == nil { continue } - if err = s.DeleteConsumer(streamName, consumerName); err != nil { + if err = js.DeleteConsumer(streamName, consumerName); err != nil { logrus.WithError(err).Errorf("Unable to clean up old consumer %q for stream %q", consumer, stream) } } } - - return s, nc }