Skip to content

Commit

Permalink
Adding support for PostgreSQL as database
Browse files Browse the repository at this point in the history
This adds support for a second database backend: PostgreSQL (in addition to sqlite3). This allows externailzing the database used by gonic.
  • Loading branch information
02strich committed Feb 26, 2023
1 parent 16e6046 commit 0031a94
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 31 deletions.
16 changes: 14 additions & 2 deletions cmd/gonic/gonic.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/google/shlex"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/oklog/run"
"github.com/peterbourgon/ff"
Expand All @@ -38,7 +39,12 @@ func main() {
confTLSKey := set.String("tls-key", "", "path to TLS private key (optional)")
confPodcastPath := set.String("podcast-path", "", "path to podcasts")
confCachePath := set.String("cache-path", "", "path to cache")
confDBPath := set.String("db-path", "gonic.db", "path to database (optional)")
confSqlitePath := set.String("db-path", "gonic.db", "path to database (optional, default: gonic.db)")
confPostgresHost := set.String("postgres-host", "", "name of the PostgreSQL gonicServer (optional)")
confPostgresPort := set.Int("postgres-port", 5432, "port to use for PostgreSQL connection (optional, default: 5432)")
confPostgresName := set.String("postgres-db", "gonic", "name of the PostgreSQL database (optional, default: gonic)")
confPostgresUser := set.String("postgres-user", "gonic", "name of the PostgreSQL user (optional, default: gonic)")
confPostgresSslModel := set.String("postgres-ssl-mode", "verify-full", "the ssl mode used for connecting to the PostreSQL instance (optional, default: verify-full)")
confScanIntervalMins := set.Int("scan-interval", 0, "interval (in minutes) to automatically scan music (optional)")
confScanAtStart := set.Bool("scan-at-start-enabled", false, "whether to perform an initial scan at startup (optional)")
confScanWatcher := set.Bool("scan-watcher-enabled", false, "whether to watch file system for new music and rescan (optional)")
Expand Down Expand Up @@ -104,7 +110,13 @@ func main() {
}
}

dbc, err := db.New(*confDBPath, db.DefaultOptions())
var dbc *db.DB
var err error
if len(*confPostgresHost) > 0 {
dbc, err = db.NewPostgres(*confPostgresHost, *confPostgresPort, *confPostgresName, *confPostgresUser, os.Getenv("GONIC_POSTGRES_PW"), *confPostgresSslModel)
} else {
dbc, err = db.NewSqlite3(*confSqlitePath, db.DefaultOptions())
}
if err != nil {
log.Fatalf("error opening database: %v\n", err)
}
Expand Down
20 changes: 17 additions & 3 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type DB struct {
*gorm.DB
}

func New(path string, options url.Values) (*DB, error) {
func NewSqlite3(path string, options url.Values) (*DB, error) {
// https://github.com/mattn/go-sqlite3#connection-string
url := url.URL{
Scheme: "file",
Expand All @@ -45,13 +45,26 @@ func New(path string, options url.Values) (*DB, error) {
if err != nil {
return nil, fmt.Errorf("with gorm: %w", err)
}
return newDB(db)
}

func NewPostgres(host string, port int, databaseName string, username string, password string, sslmode string) (*DB, error) {
pathAndArgs := fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", host, port, username, databaseName, password, sslmode)
db, err := gorm.Open("postgres", pathAndArgs)
if err != nil {
return nil, fmt.Errorf("with gorm: %w", err)
}
return newDB(db)
}

func newDB(db *gorm.DB) (*DB, error) {
db.SetLogger(log.New(os.Stdout, "gorm ", 0))
db.DB().SetMaxOpenConns(1)
return &DB{DB: db}, nil
}

func NewMock() (*DB, error) {
return New(":memory:", mockOptions())
return NewSqlite3(":memory:", mockOptions())
}

func (db *DB) GetSetting(key string) (string, error) {
Expand Down Expand Up @@ -80,10 +93,11 @@ func (db *DB) InsertBulkLeftMany(table string, head []string, left int, col []in
rows = append(rows, "(?, ?)")
values = append(values, left, c)
}
q := fmt.Sprintf("INSERT OR IGNORE INTO %q (%s) VALUES %s",
q := fmt.Sprintf("INSERT INTO %q (%s) VALUES %s ON CONFLICT (%s) DO NOTHING",
table,
strings.Join(head, ", "),
strings.Join(rows, ", "),
strings.Join(head, ", "),
)
return db.Exec(q, values...).Error
}
Expand Down
32 changes: 22 additions & 10 deletions db/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func construct(ctx MigrationContext, id string, f func(*gorm.DB, MigrationContex
func migrateInitSchema(tx *gorm.DB, _ MigrationContext) error {
return tx.AutoMigrate(
Genre{},
Artist{},
Album{},
Track{},
TrackGenre{},
AlbumGenre{},
Track{},
Artist{},
User{},
Setting{},
Play{},
Album{},
Playlist{},
PlayQueue{},
).
Expand Down Expand Up @@ -145,12 +145,18 @@ func migrateAddGenre(tx *gorm.DB, _ MigrationContext) error {

func migrateUpdateTranscodePrefIDX(tx *gorm.DB, _ MigrationContext) error {
var hasIDX int
tx.
Select("1").
Table("sqlite_master").
Where("type = ?", "index").
Where("name = ?", "idx_user_id_client").
Count(&hasIDX)
if tx.Dialect().GetName() == "sqlite3" {
tx.Select("1").
Table("sqlite_master").
Where("type = ?", "index").
Where("name = ?", "idx_user_id_client").
Count(&hasIDX)
} else if tx.Dialect().GetName() == "postgres" {
tx.Select("1").
Table("pg_indexes").
Where("indexname = ?", "idx_user_id_client").
Count(&hasIDX)
}
if hasIDX == 1 {
// index already exists
return nil
Expand Down Expand Up @@ -420,9 +426,15 @@ func migratePlaylistsQueuesToFullID(tx *gorm.DB, _ MigrationContext) error {
if err := step.Error; err != nil {
return fmt.Errorf("step migrate play_queues to full id: %w", err)
}
step = tx.Exec(`
if tx.Dialect().GetName() == "postgres" {
step = tx.Exec(`
UPDATE play_queues SET newcurrent=('tr-' || current)::varchar[200];
`)
} else {
step = tx.Exec(`
UPDATE play_queues SET newcurrent=('tr-' || CAST(current AS varchar(10)));
`)
}
if err := step.Error; err != nil {
return fmt.Errorf("step migrate play_queues to full id: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion mockfs/mockfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func (m *MockFS) DumpDB(suffix ...string) {
p = append(p, suffix...)

destPath := filepath.Join(os.TempDir(), strings.Join(p, "-"))
dest, err := db.New(destPath, url.Values{})
dest, err := db.NewSqlite3(destPath, url.Values{})
if err != nil {
m.t.Fatalf("create dest db: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion scanner/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func TestMultiFolderWithSharedArtist(t *testing.T) {

sq := func(db *gorm.DB) *gorm.DB {
return db.
Select("*, count(sub.id) child_count, sum(sub.length) duration").
Select("albums.*, count(sub.id) child_count, sum(sub.length) duration").
Joins("LEFT JOIN tracks sub ON albums.id=sub.album_id").
Group("albums.id")
}
Expand Down
6 changes: 3 additions & 3 deletions server/ctrlsubsonic/handlers_by_folder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ func (c *Controller) ServeGetIndexes(r *http.Request) *spec.Response {
}
var folders []*db.Album
c.DB.
Select("*, count(sub.id) child_count").
Select("albums.*, count(sub.id) child_count").
Preload("AlbumStar", "user_id=?", user.ID).
Preload("AlbumRating", "user_id=?", user.ID).
Joins("LEFT JOIN albums sub ON albums.id=sub.parent_id").
Where("albums.parent_id IN ?", rootQ.SubQuery()).
Group("albums.id").
Order("albums.right_path COLLATE NOCASE").
Order("albums.right_path").
Find(&folders)
// [a-z#] -> 27
indexMap := make(map[string]*spec.Index, 27)
Expand Down Expand Up @@ -80,7 +80,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response {
Where("parent_id=?", id.Value).
Preload("AlbumStar", "user_id=?", user.ID).
Preload("AlbumRating", "user_id=?", user.ID).
Order("albums.right_path COLLATE NOCASE").
Order("albums.right_path").
Find(&childFolders)
for _, ch := range childFolders {
childrenObj = append(childrenObj, spec.NewTCAlbumByFolder(ch))
Expand Down
13 changes: 7 additions & 6 deletions server/ctrlsubsonic/handlers_by_tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response {
user := r.Context().Value(CtxUser).(*db.User)
var artists []*db.Artist
q := c.DB.
Select("*, count(sub.id) album_count").
Select("artists.*, count(sub.id) album_count").
Joins("LEFT JOIN albums sub ON artists.id=sub.tag_artist_id").
Preload("ArtistStar", "user_id=?", user.ID).
Preload("ArtistRating", "user_id=?", user.ID).
Group("artists.id").
Order("artists.name COLLATE NOCASE")
Order("artists.name")
if m := getMusicFolder(c.MusicPaths, params); m != "" {
q = q.Where("sub.root_dir=?", m)
}
Expand Down Expand Up @@ -68,7 +68,7 @@ func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response {
c.DB.
Preload("Albums", func(db *gorm.DB) *gorm.DB {
return db.
Select("*, count(sub.id) child_count, sum(sub.length) duration").
Select("albums.*, count(sub.id) child_count, sum(sub.length) duration").
Joins("LEFT JOIN tracks sub ON albums.id=sub.album_id").
Preload("AlbumStar", "user_id=?", user.ID).
Preload("AlbumRating", "user_id=?", user.ID).
Expand Down Expand Up @@ -99,6 +99,7 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response {
err = c.DB.
Select("albums.*, count(tracks.id) child_count, sum(tracks.length) duration").
Joins("LEFT JOIN tracks ON tracks.album_id=albums.id").
Group("albums.id").
Preload("TagArtist").
Preload("Genres").
Preload("Tracks", func(db *gorm.DB) *gorm.DB {
Expand Down Expand Up @@ -163,14 +164,14 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
case "frequent":
user := r.Context().Value(CtxUser).(*db.User)
q = q.Joins("JOIN plays ON albums.id=plays.album_id AND plays.user_id=?", user.ID)
q = q.Order("plays.count DESC")
q = q.Order("SUM(plays.count) DESC")
case "newest":
q = q.Order("created_at DESC")
case "random":
q = q.Order(gorm.Expr("random()"))
case "recent":
q = q.Joins("JOIN plays ON albums.id=plays.album_id AND plays.user_id=?", user.ID)
q = q.Order("plays.time DESC")
q = q.Order("MAX(plays.time) DESC")
case "starred":
q = q.Joins("JOIN album_stars ON albums.id=album_stars.album_id AND album_stars.user_id=?", user.ID)
q = q.Order("tag_title")
Expand Down Expand Up @@ -218,7 +219,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response {
// search artists
var artists []*db.Artist
q := c.DB.
Select("*, count(albums.id) album_count").
Select("artists.*, count(albums.id) album_count").
Group("artists.id").
Where("name LIKE ? OR name_u_dec LIKE ?", query, query).
Joins("JOIN albums ON albums.tag_artist_id=artists.id").
Expand Down
2 changes: 1 addition & 1 deletion server/ctrlsubsonic/handlers_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func streamGetTransPref(dbc *db.DB, userID int, client string) (*db.TranscodePre
var pref db.TranscodePreference
err := dbc.
Where("user_id=?", userID).
Where("client COLLATE NOCASE IN (?)", []string{"*", client}).
Where("client IN (?)", []string{"*", client}).
Order("client DESC"). // ensure "*" is last if it's there
First(&pref).
Error
Expand Down
11 changes: 7 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"encoding/base64"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -67,17 +68,19 @@ func New(opts Options) (*Server, error) {
r.Use(base.WithCORS)
r.Use(handlers.RecoveryHandler(handlers.PrintRecoveryStack(true)))

sessKey, err := opts.DB.GetSetting("session_key")
encSessKey, err := opts.DB.GetSetting("session_key")
if err != nil {
return nil, fmt.Errorf("get session key: %w", err)
}
if sessKey == "" {
if err := opts.DB.SetSetting("session_key", string(securecookie.GenerateRandomKey(32))); err != nil {
sessKey, err := base64.StdEncoding.DecodeString(encSessKey)
if err != nil || len(sessKey) == 0 {
sessKey = securecookie.GenerateRandomKey(32)
if err := opts.DB.SetSetting("session_key", base64.StdEncoding.EncodeToString(sessKey)); err != nil {
return nil, fmt.Errorf("set session key: %w", err)
}
}

sessDB := gormstore.New(opts.DB.DB, []byte(sessKey))
sessDB := gormstore.New(opts.DB.DB, sessKey)
sessDB.SessionOpts.HttpOnly = true
sessDB.SessionOpts.SameSite = http.SameSiteLaxMode

Expand Down

0 comments on commit 0031a94

Please sign in to comment.