diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 201f36df..246342c9 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -14,9 +14,9 @@ import ( "regexp" "strings" "sync" + "time" "github.com/element-hq/dendrite/roomserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" ) @@ -27,9 +27,8 @@ type ServerACLDatabase interface { // RoomsWithACLs returns all room IDs for rooms with ACLs RoomsWithACLs(ctx context.Context) ([]string, error) - // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. - // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. - GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) + // GetBulkStateACLs returns all server ACLs for the given rooms. + GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) } type ServerACLs struct { @@ -40,6 +39,16 @@ type ServerACLs struct { } func NewServerACLs(db ServerACLDatabase) *ServerACLs { + // Add some logging, as this can take a while on larger instances. + logrus.Infof("Loading server ACLs...") + start := time.Now() + aclCount := 0 + defer func() { + logrus.WithFields(logrus.Fields{ + "duration": time.Since(start), + "acls": aclCount, + }).Info("Finished loading server ACLs") + }() ctx := context.TODO() acls := &ServerACLs{ acls: make(map[string]*serverACL), @@ -48,20 +57,25 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { aclRegexCache: make(map[string]**regexp.Regexp, 100), } - // Look up all of the rooms that the current state server knows about. + // Look up all rooms with ACLs. rooms, err := db.RoomsWithACLs(ctx) if err != nil { logrus.WithError(err).Fatalf("Failed to get known rooms") } - // For each room, let's see if we have a server ACL state event. If we - // do then we'll process it into memory so that we have the regexes to - // hand. - events, err := db.GetBulkStateContent(ctx, rooms, []gomatrixserverlib.StateKeyTuple{{EventType: MRoomServerACL, StateKey: ""}}, false) + // No rooms with ACLs, don't bother hitting the DB again. + if len(rooms) == 0 { + return acls + } + + // Get ACLs for the required rooms, bail if we are unable to get them. + events, err := db.GetBulkStateACLs(ctx, rooms) if err != nil { - logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err) + logrus.WithError(err).Fatal("Failed to get server ACLs for all rooms") } + aclCount = len(events) + for _, event := range events { acls.OnServerACLUpdate(event) } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index 16f3887d..d5a36a61 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -12,7 +12,6 @@ import ( "testing" "github.com/element-hq/dendrite/roomserver/storage/tables" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" ) @@ -108,11 +107,11 @@ var ( type dummyACLDB struct{} -func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) { +func (d dummyACLDB) RoomsWithACLs(_ context.Context) ([]string, error) { return []string{"1", "2"}, nil } -func (d dummyACLDB) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { +func (d dummyACLDB) GetBulkStateACLs(_ context.Context, _ []string) ([]tables.StrippedEvent, error) { return []tables.StrippedEvent{ { RoomID: "1", diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 49086dba..3bdeeef8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -187,6 +187,8 @@ type Database interface { // RoomsWithACLs returns all room IDs for rooms with ACLs RoomsWithACLs(ctx context.Context) ([]string, error) + // GetBulkStateACLs returns all server ACLs for the given rooms. + GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) QueryAdminEventReports(ctx context.Context, from uint64, limit uint64, backwards bool, userID string, roomID string) ([]api.QueryAdminEventReportsResponse, int64, error) QueryAdminEventReport(ctx context.Context, reportID uint64) (api.QueryAdminEventReportResponse, error) AdminDeleteEventReport(ctx context.Context, reportID uint64) error diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index a4eb0eb9..ac5f54ce 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1437,6 +1437,63 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID, return roomIDs, nil } +// GetBulkStateACLs is a lighter weight form of GetBulkStateContent, which only returns ACL state events. +func (d *Database) GetBulkStateACLs(ctx context.Context, roomIDs []string) ([]tables.StrippedEvent, error) { + tuples := []gomatrixserverlib.StateKeyTuple{{EventType: "m.room.server_acl", StateKey: ""}} + + var eventNIDs []types.EventNID + eventNIDToVer := make(map[types.EventNID]gomatrixserverlib.RoomVersion) + // TODO: This feels like this is going to be really slow... + for _, roomID := range roomIDs { + roomInfo, err2 := d.roomInfo(ctx, nil, roomID) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load room info for room %s : %w", roomID, err2) + } + // for unknown rooms or rooms which we don't have the current state, skip them. + if roomInfo == nil || roomInfo.IsStub() { + continue + } + // No querier needed, as we don't actually do state resolution + stateRes := state.NewStateResolution(d, roomInfo, nil) + entries, err2 := stateRes.LoadStateAtSnapshotForStringTuples(ctx, roomInfo.StateSnapshotNID(), tuples) + if err2 != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load state for room %s : %w", roomID, err2) + } + for _, entry := range entries { + eventNIDs = append(eventNIDs, entry.EventNID) + eventNIDToVer[entry.EventNID] = roomInfo.RoomVersion + } + } + eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs) + if err != nil { + eventIDs = map[types.EventNID]string{} + } + events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs) + if err != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event nids: %w", err) + } + result := make([]tables.StrippedEvent, len(events)) + for i := range events { + roomVer := eventNIDToVer[events[i].EventNID] + verImpl, err := gomatrixserverlib.GetRoomVersion(roomVer) + if err != nil { + return nil, err + } + ev, err := verImpl.NewEventFromTrustedJSONWithEventID(eventIDs[events[i].EventNID], events[i].EventJSON, false) + if err != nil { + return nil, fmt.Errorf("GetBulkStateACLs: failed to load event JSON for event NID %v : %w", events[i].EventNID, err) + } + result[i] = tables.StrippedEvent{ + EventType: ev.Type(), + RoomID: ev.RoomID().String(), + StateKey: *ev.StateKey(), + ContentValue: tables.ExtractContentValue(&types.HeaderedEvent{PDU: ev}), + } + } + + return result, nil +} + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) { @@ -1487,6 +1544,9 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu if roomInfo == nil || roomInfo.IsStub() { continue } + // TODO: This is inefficient as we're loading the _entire_ state, but only care about a subset of it. + // This is why GetBulkStateACLs exists. LoadStateAtSnapshotForStringTuples only loads the state we care about, + // but is unfortunately not able to load wildcard state keys. entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID()) if err2 != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2)