diff --git a/cmd/gonic/gonic.go b/cmd/gonic/gonic.go index e445df4a..e67e7502 100644 --- a/cmd/gonic/gonic.go +++ b/cmd/gonic/gonic.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/shlex" + _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/oklog/run" "github.com/peterbourgon/ff" @@ -36,7 +37,12 @@ func main() { confTLSKey := set.String("tls-key", "", "path to TLS private key (optional)") confPodcastPath := set.String("podcast-path", "", "path to podcasts") confCachePath := set.String("cache-path", "", "path to cache") - confDBPath := set.String("db-path", "gonic.db", "path to database (optional)") + confSqlitePath := set.String("db-path", "gonic.db", "path to database (optional, default: gonic.db)") + confPostgresHost := set.String("postgres-host", "", "name of the PostgreSQL gonicServer (optional)") + confPostgresPort := set.Int("postgres-port", 5432, "port to use for PostgreSQL connection (optional, default: 5432)") + confPostgresName := set.String("postgres-db", "gonic", "name of the PostgreSQL database (optional, default: gonic)") + confPostgresUser := set.String("postgres-user", "gonic", "name of the PostgreSQL user (optional, default: gonic)") + confPostgresSslModel := set.String("postgres-ssl-mode", "verify-full", "the ssl mode used for connecting to the PostreSQL instance (optional, default: verify-full)") confScanIntervalMins := set.Int("scan-interval", 0, "interval (in minutes) to automatically scan music (optional)") confScanAtStart := set.Bool("scan-at-start-enabled", false, "whether to perform an initial scan at startup (optional)") confScanWatcher := set.Bool("scan-watcher-enabled", false, "whether to watch file system for new music and rescan (optional)") @@ -102,7 +108,13 @@ func main() { } } - dbc, err := db.New(*confDBPath, db.DefaultOptions()) + var dbc *db.DB + var err error + if len(*confPostgresHost) > 0 { + dbc, err = db.NewPostgres(*confPostgresHost, *confPostgresPort, *confPostgresName, *confPostgresUser, os.Getenv("GONIC_POSTGRES_PW"), *confPostgresSslModel) + } else { + dbc, err = db.NewSqlite3(*confSqlitePath, db.DefaultOptions()) + } if err != nil { log.Fatalf("error opening database: %v\n", err) } diff --git a/db/db.go b/db/db.go index 7e4edaf0..41cbaee4 100644 --- a/db/db.go +++ b/db/db.go @@ -34,7 +34,7 @@ type DB struct { *gorm.DB } -func New(path string, options url.Values) (*DB, error) { +func NewSqlite3(path string, options url.Values) (*DB, error) { // https://github.com/mattn/go-sqlite3#connection-string url := url.URL{ Scheme: "file", @@ -45,13 +45,26 @@ func New(path string, options url.Values) (*DB, error) { if err != nil { return nil, fmt.Errorf("with gorm: %w", err) } + return newDB(db) +} + +func NewPostgres(host string, port int, databaseName string, username string, password string, sslmode string) (*DB, error) { + pathAndArgs := fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", host, port, username, databaseName, password, sslmode) + db, err := gorm.Open("postgres", pathAndArgs) + if err != nil { + return nil, fmt.Errorf("with gorm: %w", err) + } + return newDB(db) +} + +func newDB(db *gorm.DB) (*DB, error) { db.SetLogger(log.New(os.Stdout, "gorm ", 0)) db.DB().SetMaxOpenConns(1) return &DB{DB: db}, nil } func NewMock() (*DB, error) { - return New(":memory:", mockOptions()) + return NewSqlite3(":memory:", mockOptions()) } func (db *DB) GetSetting(key string) (string, error) { @@ -80,10 +93,11 @@ func (db *DB) InsertBulkLeftMany(table string, head []string, left int, col []in rows = append(rows, "(?, ?)") values = append(values, left, c) } - q := fmt.Sprintf("INSERT OR IGNORE INTO %q (%s) VALUES %s", + q := fmt.Sprintf("INSERT INTO %q (%s) VALUES %s ON CONFLICT (%s) DO NOTHING", table, strings.Join(head, ", "), strings.Join(rows, ", "), + strings.Join(head, ", "), ) return db.Exec(q, values...).Error } diff --git a/db/migrations.go b/db/migrations.go index a84025b9..83f8497d 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -73,14 +73,14 @@ func construct(ctx MigrationContext, id string, f func(*gorm.DB, MigrationContex func migrateInitSchema(tx *gorm.DB, _ MigrationContext) error { return tx.AutoMigrate( Genre{}, + Artist{}, + Album{}, + Track{}, TrackGenre{}, AlbumGenre{}, - Track{}, - Artist{}, User{}, Setting{}, Play{}, - Album{}, Playlist{}, PlayQueue{}, ). @@ -144,12 +144,18 @@ func migrateAddGenre(tx *gorm.DB, _ MigrationContext) error { func migrateUpdateTranscodePrefIDX(tx *gorm.DB, _ MigrationContext) error { var hasIDX int - tx. - Select("1"). - Table("sqlite_master"). - Where("type = ?", "index"). - Where("name = ?", "idx_user_id_client"). - Count(&hasIDX) + if tx.Dialect().GetName() == "sqlite3" { + tx.Select("1"). + Table("sqlite_master"). + Where("type = ?", "index"). + Where("name = ?", "idx_user_id_client"). + Count(&hasIDX) + } else if tx.Dialect().GetName() == "postgres" { + tx.Select("1"). + Table("pg_indexes"). + Where("indexname = ?", "idx_user_id_client"). + Count(&hasIDX) + } if hasIDX == 1 { // index already exists return nil diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 39dc135e..434f0f7e 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -397,7 +397,7 @@ func TestMultiFolderWithSharedArtist(t *testing.T) { sq := func(db *gorm.DB) *gorm.DB { return db. - Select("*, count(sub.id) child_count, sum(sub.length) duration"). + Select("albums.*, count(sub.id) child_count, sum(sub.length) duration"). Joins("LEFT JOIN tracks sub ON albums.id=sub.album_id"). Group("albums.id") } diff --git a/server/ctrlsubsonic/handlers_by_folder.go b/server/ctrlsubsonic/handlers_by_folder.go index b23cfb28..b866a73a 100644 --- a/server/ctrlsubsonic/handlers_by_folder.go +++ b/server/ctrlsubsonic/handlers_by_folder.go @@ -31,13 +31,13 @@ func (c *Controller) ServeGetIndexes(r *http.Request) *spec.Response { } var folders []*db.Album c.DB. - Select("*, count(sub.id) child_count"). + Select("albums.*, count(sub.id) child_count"). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). Joins("LEFT JOIN albums sub ON albums.id=sub.parent_id"). Where("albums.parent_id IN ?", rootQ.SubQuery()). Group("albums.id"). - Order("albums.right_path COLLATE NOCASE"). + Order("albums.right_path"). Find(&folders) // [a-z#] -> 27 indexMap := make(map[string]*spec.Index, 27) @@ -80,7 +80,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { Where("parent_id=?", id.Value). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). - Order("albums.right_path COLLATE NOCASE"). + Order("albums.right_path"). Find(&childFolders) for _, ch := range childFolders { childrenObj = append(childrenObj, spec.NewTCAlbumByFolder(ch)) diff --git a/server/ctrlsubsonic/handlers_by_tags.go b/server/ctrlsubsonic/handlers_by_tags.go index f36a2c16..b946ed41 100644 --- a/server/ctrlsubsonic/handlers_by_tags.go +++ b/server/ctrlsubsonic/handlers_by_tags.go @@ -24,12 +24,12 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { user := r.Context().Value(CtxUser).(*db.User) var artists []*db.Artist q := c.DB. - Select("*, count(sub.id) album_count"). + Select("artists.*, count(sub.id) album_count"). Joins("LEFT JOIN albums sub ON artists.id=sub.tag_artist_id"). Preload("ArtistStar", "user_id=?", user.ID). Preload("ArtistRating", "user_id=?", user.ID). Group("artists.id"). - Order("artists.name COLLATE NOCASE") + Order("artists.name") if m := getMusicFolder(c.MusicPaths, params); m != "" { q = q.Where("sub.root_dir=?", m) } @@ -68,7 +68,7 @@ func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response { c.DB. Preload("Albums", func(db *gorm.DB) *gorm.DB { return db. - Select("*, count(sub.id) child_count, sum(sub.length) duration"). + Select("albums.*, count(sub.id) child_count, sum(sub.length) duration"). Joins("LEFT JOIN tracks sub ON albums.id=sub.album_id"). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). @@ -99,6 +99,7 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { err = c.DB. Select("albums.*, count(tracks.id) child_count, sum(tracks.length) duration"). Joins("LEFT JOIN tracks ON tracks.album_id=albums.id"). + Group("albums.id"). Preload("TagArtist"). Preload("Genres"). Preload("Tracks", func(db *gorm.DB) *gorm.DB { @@ -163,14 +164,14 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { case "frequent": user := r.Context().Value(CtxUser).(*db.User) q = q.Joins("JOIN plays ON albums.id=plays.album_id AND plays.user_id=?", user.ID) - q = q.Order("plays.count DESC") + q = q.Order("SUM(plays.count) DESC") case "newest": q = q.Order("created_at DESC") case "random": q = q.Order(gorm.Expr("random()")) case "recent": q = q.Joins("JOIN plays ON albums.id=plays.album_id AND plays.user_id=?", user.ID) - q = q.Order("plays.time DESC") + q = q.Order("MAX(plays.time) DESC") case "starred": q = q.Joins("JOIN album_stars ON albums.id=album_stars.album_id AND album_stars.user_id=?", user.ID) q = q.Order("tag_title") @@ -218,7 +219,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { // search artists var artists []*db.Artist q := c.DB. - Select("*, count(albums.id) album_count"). + Select("artists.*, count(albums.id) album_count"). Group("artists.id"). Where("name LIKE ? OR name_u_dec LIKE ?", query, query). Joins("JOIN albums ON albums.tag_artist_id=artists.id"). diff --git a/server/ctrlsubsonic/handlers_raw.go b/server/ctrlsubsonic/handlers_raw.go index 16b89a8d..cf6eb77e 100644 --- a/server/ctrlsubsonic/handlers_raw.go +++ b/server/ctrlsubsonic/handlers_raw.go @@ -30,7 +30,7 @@ func streamGetTransPref(dbc *db.DB, userID int, client string) (*db.TranscodePre var pref db.TranscodePreference err := dbc. Where("user_id=?", userID). - Where("client COLLATE NOCASE IN (?)", []string{"*", client}). + Where("client IN (?)", []string{"*", client}). Order("client DESC"). // ensure "*" is last if it's there First(&pref). Error diff --git a/server/server.go b/server/server.go index 730739a1..99c799e2 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "encoding/base64" "fmt" "log" "net/http" @@ -70,17 +71,19 @@ func New(opts Options) (*Server, error) { } r.Use(base.WithCORS) - sessKey, err := opts.DB.GetSetting("session_key") + encSessKey, err := opts.DB.GetSetting("session_key") if err != nil { return nil, fmt.Errorf("get session key: %w", err) } - if sessKey == "" { - if err := opts.DB.SetSetting("session_key", string(securecookie.GenerateRandomKey(32))); err != nil { + sessKey, err := base64.StdEncoding.DecodeString(encSessKey) + if err != nil || len(sessKey) == 0 { + sessKey = securecookie.GenerateRandomKey(32) + if err := opts.DB.SetSetting("session_key", base64.StdEncoding.EncodeToString(sessKey)); err != nil { return nil, fmt.Errorf("set session key: %w", err) } } - sessDB := gormstore.New(opts.DB.DB, []byte(sessKey)) + sessDB := gormstore.New(opts.DB.DB, sessKey) sessDB.SessionOpts.HttpOnly = true sessDB.SessionOpts.SameSite = http.SameSiteLaxMode