Skip to content

Commit

Permalink
Refactor some JetStream helper code, add support for specifying JetSt…
Browse files Browse the repository at this point in the history
…ream domain

Signed-off-by: Neil Alexander <[email protected]>
  • Loading branch information
neilalexander committed Dec 31, 2024
1 parent add73ec commit a9824e7
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 107 deletions.
2 changes: 2 additions & 0 deletions setup/config/config_jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ type JetStream struct {
// The prefix to use for stream names for this homeserver - really only
// useful if running more than one Dendrite on the same NATS deployment.
TopicPrefix string `yaml:"topic_prefix"`
// The JetStream domain, if needed.
JetStreamDomain string `yaml:"js_domain"`
// Keep all storage in memory. This is mostly useful for unit tests.
InMemory bool `yaml:"in_memory"`
// Disable logging. This is mostly useful for unit tests.
Expand Down
141 changes: 74 additions & 67 deletions setup/jetstream/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func JetStreamConsumer(
f func(ctx context.Context, msgs []*nats.Msg) bool,
opts ...nats.SubOpt,
) error {
defer func() {
defer func(durable string) {
// If there are existing consumers from before they were pull
// consumers, we need to clean up the old push consumers. However,
// in order to not affect the interest-based policies, we need to
Expand All @@ -33,86 +33,93 @@ func JetStreamConsumer(
logrus.WithContext(ctx).Warnf("Failed to clean up old consumer %q", durable)
}
}
}()
}(durable)

name := durable + "Pull"
sub, err := js.PullSubscribe(subj, name, opts...)
durable = durable + "Pull"
sub, err := js.PullSubscribe(subj, durable, opts...)
if err != nil {
sentry.CaptureException(err)
return fmt.Errorf("nats.SubscribeSync: %w", err)
logrus.WithContext(ctx).WithError(err).Warnf("Failed to configure durable %q", durable)
return err
}
go func() {
for {
// If the parent context has given up then there's no point in
// carrying on doing anything, so stop the listener.
select {
case <-ctx.Done():
if err := sub.Unsubscribe(); err != nil {
logrus.WithContext(ctx).Warnf("Failed to unsubscribe %q", durable)
}
return
default:
}
// The context behaviour here is surprising — we supply a context
// so that we can interrupt the fetch if we want, but NATS will still
// enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(batch, nats.Context(ctx))
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired
// or whether it was our supplied context.
select {
case <-ctx.Done():
// The supplied context expired, so we want to stop the
// consumer altogether.
return
default:
// The JetStream context expired, so the fetch probably
// just timed out and we should try again.
continue
}
} else if errors.Is(err, nats.ErrConsumerDeleted) {
// The consumer was deleted so stop.
go jetStreamConsumerWorker(ctx, sub, subj, batch, f)
return nil
}

func jetStreamConsumerWorker(
ctx context.Context, sub *nats.Subscription, subj string, batch int,
f func(ctx context.Context, msgs []*nats.Msg) bool,
) {
for {
// If the parent context has given up then there's no point in
// carrying on doing anything, so stop the listener.
select {
case <-ctx.Done():
return
default:
}
// The context behaviour here is surprising — we supply a context
// so that we can interrupt the fetch if we want, but NATS will still
// enforce its own deadline (roughly 5 seconds by default). Therefore
// it is our responsibility to check whether our context expired or
// not when a context error is returned. Footguns. Footguns everywhere.
msgs, err := sub.Fetch(batch, nats.Context(ctx))
if err != nil {
if err == context.Canceled || err == context.DeadlineExceeded {
// Work out whether it was the JetStream context that expired
// or whether it was our supplied context.
select {
case <-ctx.Done():
// The supplied context expired, so we want to stop the
// consumer altogether.
return
} else {
// Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string
if err.Error() == "nats: Server Shutdown" {
logrus.WithContext(ctx).Warn("nats server shutting down")
return
}
// Something else went wrong, so we'll panic.
sentry.CaptureException(err)
logrus.WithContext(ctx).WithField("subject", subj).Fatal(err)
default:
// The JetStream context expired, so the fetch probably
// just timed out and we should try again.
continue
}
} else if errors.Is(err, nats.ErrTimeout) {
// Pull request was invalidated, try again.
continue
} else if errors.Is(err, nats.ErrConsumerLeadershipChanged) {
// Leadership changed so pending pull requests became invalidated,
// just try again.
continue
} else if err.Error() == "nats: Server Shutdown" {
// The server is shutting down, but we'll rely on reconnect
// behaviour to try and either connect us to another node (if
// clustered) or to reconnect when the server comes back up.
continue
} else {
// Something else went wrong.
logrus.WithContext(ctx).WithField("subject", subj).WithError(err).Warn("Error on pull subscriber fetch")
return
}
if len(msgs) < 1 {
}
if len(msgs) < 1 {
continue
}
for _, msg := range msgs {
if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
sentry.CaptureException(err)
continue
}
}
if f(ctx, msgs) {
for _, msg := range msgs {
if err = msg.InProgress(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.InProgress: %w", err))
if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err)
continue
}
}
if f(ctx, msgs) {
for _, msg := range msgs {
if err = msg.AckSync(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.AckSync: %w", err))
sentry.CaptureException(err)
}
}
} else {
for _, msg := range msgs {
if err = msg.Nak(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
sentry.CaptureException(err)
}
} else {
for _, msg := range msgs {
if err = msg.Nak(nats.Context(ctx)); err != nil {
logrus.WithContext(ctx).WithField("subject", subj).Warn(fmt.Errorf("msg.Nak: %w", err))
sentry.CaptureException(err)
}
}
}
}()
return nil
}
}
91 changes: 51 additions & 40 deletions setup/jetstream/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"sync"
"time"

"github.com/getsentry/sentry-go"
"github.com/sirupsen/logrus"

"github.com/element-hq/dendrite/setup/config"
Expand Down Expand Up @@ -36,17 +35,20 @@ func DeleteAllStreams(js natsclient.JetStreamContext, cfg *config.JetStream) {
func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
natsLock.Lock()
defer natsLock.Unlock()
// check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 {
// reuse existing connections
if s.nc != nil {
return s.js, s.nc
}
var err error

// If an existing connection exists, return it.
if s.nc != nil && s.js != nil {
return s.js, s.nc
}

// For connecting to an external NATS server.
if len(cfg.Addresses) > 0 {
s.js, s.nc = setupNATS(process, cfg, nil)
return s.js, s.nc
}
if s.Server == nil {
var err error

if len(cfg.Addresses) == 0 && s.Server == nil {
opts := &natsserver.Options{
ServerName: "monolith",
DontListen: true,
Expand All @@ -58,8 +60,7 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
NoLog: cfg.NoLog,
SyncAlways: true,
}
s.Server, err = natsserver.NewServer(opts)
if err != nil {
if s.Server, err = natsserver.NewServer(opts); err != nil {
panic(err)
}
if !cfg.NoLog {
Expand All @@ -75,53 +76,67 @@ func (s *NATSInstance) Prepare(process *process.ProcessContext, cfg *config.JetS
s.WaitForShutdown()
process.ComponentFinished()
}()
if !s.ReadyForConnections(time.Second * 60) {
logrus.Fatalln("NATS did not start in time")
}
}
if !s.ReadyForConnections(time.Second * 60) {
logrus.Fatalln("NATS did not start in time")
}
// reuse existing connections
if s.nc != nil {
return s.js, s.nc
}
nc, err := natsclient.Connect("", natsclient.InProcessServer(s))
if err != nil {

// No existing process connection, create a new one.
if s.nc, err = natsclient.Connect("", natsclient.InProcessServer(s.Server)); err != nil {
logrus.Fatalln("Failed to create NATS client")
}
js, _ := setupNATS(process, cfg, nc)
s.js = js
s.nc = nc
return js, nc
s.js, s.nc = setupNATS(process, cfg, s.nc)
return s.js, s.nc
}

// nolint:gocyclo
func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsclient.Conn) (natsclient.JetStreamContext, *natsclient.Conn) {
jsOpts := []natsclient.JSOpt{}
if cfg.JetStreamDomain != "" {
jsOpts = append(jsOpts, natsclient.Domain(cfg.JetStreamDomain))
}

if nc == nil {
var err error
opts := []natsclient.Option{}
opts := []natsclient.Option{
natsclient.Name("Dendrite"),
natsclient.MaxReconnects(-1), // Try forever
natsclient.ReconnectJitter(time.Second, time.Second),
natsclient.ReconnectWait(time.Second * 10),
natsclient.ReconnectHandler(func(c *natsclient.Conn) {
js, jerr := c.JetStream(jsOpts...)
if jerr != nil {
logrus.WithError(jerr).Panic("Unable to get JetStream context in reconnect handler")
return
}
checkAndConfigureStreams(process, cfg, js)
}),
}
if cfg.DisableTLSValidation {
opts = append(opts, natsclient.Secure(&tls.Config{
InsecureSkipVerify: true,
}))
}
if string(cfg.Credentials) != "" {
opts = append(opts, natsclient.UserCredentials(string(cfg.Credentials)))
}
nc, err = natsclient.Connect(strings.Join(cfg.Addresses, ","), opts...)
if err != nil {
logrus.WithError(err).Panic("Unable to connect to NATS")
return nil, nil
}
}

s, err := nc.JetStream()
js, err := nc.JetStream(jsOpts...)
if err != nil {
logrus.WithError(err).Panic("Unable to get JetStream context")
return nil, nil
}
checkAndConfigureStreams(process, cfg, js)
return js, nc
}

func checkAndConfigureStreams(process *process.ProcessContext, cfg *config.JetStream, js natsclient.JetStreamContext) {
for _, stream := range streams { // streams are defined in streams.go
name := cfg.Prefixed(stream.Name)
info, err := s.StreamInfo(name)
info, err := js.StreamInfo(name)
if err != nil && err != natsclient.ErrStreamNotFound {
logrus.WithError(err).Fatal("Unable to get stream info")
}
Expand Down Expand Up @@ -153,11 +168,11 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
case info.Config.MaxAge != stream.MaxAge:
// Try updating the stream first, as many things can be updated
// non-destructively.
if info, err = s.UpdateStream(stream); err != nil {
if info, err = js.UpdateStream(stream); err != nil {
logrus.WithError(err).Warnf("Unable to update stream %q, recreating...", name)
// We failed to update the stream, this is a last attempt to get
// things working but may result in data loss.
if err = s.DeleteStream(name); err != nil {
if err = js.DeleteStream(name); err != nil {
logrus.WithError(err).Fatalf("Unable to delete stream %q", name)
}
info = nil
Expand All @@ -176,7 +191,7 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
namespaced := *stream
namespaced.Name = name
namespaced.Subjects = subjects
if _, err = s.AddStream(&namespaced); err != nil {
if _, err = js.AddStream(&namespaced); err != nil {
logger := logrus.WithError(err).WithFields(logrus.Fields{
"stream": namespaced.Name,
"subjects": namespaced.Subjects,
Expand All @@ -193,10 +208,9 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
// we can't recover anything that was queued on the disk but we
// will still be able to start and run hopefully in the meantime.
logger.WithError(err).Error("Unable to add stream")
sentry.CaptureException(fmt.Errorf("Unable to add stream %q: %w", namespaced.Name, err))

namespaced.Storage = natsclient.MemoryStorage
if _, err = s.AddStream(&namespaced); err != nil {
if _, err = js.AddStream(&namespaced); err != nil {
// We tried to add the stream in-memory instead but something
// went wrong. That's an unrecoverable situation so we will
// give up at this point.
Expand All @@ -208,7 +222,6 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
// disk will be left alone, but our ability to recover from a
// future crash will be limited. Yell about it.
err := fmt.Errorf("Stream %q is running in-memory; this may be due to data corruption in the JetStream storage directory", namespaced.Name)
sentry.CaptureException(err)
process.Degraded(err)
}
}
Expand All @@ -229,15 +242,13 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc
streamName := cfg.Matrix.JetStream.Prefixed(stream)
for _, consumer := range consumers {
consumerName := cfg.Matrix.JetStream.Prefixed(consumer) + "Pull"
consumerInfo, err := s.ConsumerInfo(streamName, consumerName)
consumerInfo, err := js.ConsumerInfo(streamName, consumerName)
if err != nil || consumerInfo == nil {
continue
}
if err = s.DeleteConsumer(streamName, consumerName); err != nil {
if err = js.DeleteConsumer(streamName, consumerName); err != nil {
logrus.WithError(err).Errorf("Unable to clean up old consumer %q for stream %q", consumer, stream)
}
}
}

return s, nc
}

0 comments on commit a9824e7

Please sign in to comment.