-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #50 from elisasre/auth
add auth store
- Loading branch information
Showing
14 changed files
with
458 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package auth | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/elisasre/go-common" | ||
database "github.com/elisasre/go-common/auth/db" | ||
"github.com/elisasre/go-common/auth/memory" | ||
|
||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
// AuthInterface will contain interface to interact with different auth providers. | ||
type AuthInterface interface { | ||
GetKeys() []common.JWTKey | ||
GetCurrentKey() common.JWTKey | ||
RotateKeys() error | ||
RefreshKeys(bool) ([]common.JWTKey, error) | ||
} | ||
|
||
func AuthProvider(mode string, store common.Datastore) (AuthInterface, error) { | ||
log.Info().Str("mode", mode).Msg("Using AuthProvider") | ||
switch mode { | ||
case "memory": | ||
return memory.NewMemory() | ||
case "database": | ||
return database.NewDatabase(store) | ||
default: | ||
return nil, fmt.Errorf("unknown auth mode '%s'", mode) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package database | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
"time" | ||
|
||
"github.com/elisasre/go-common" | ||
|
||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
// Database is an implementation of Interface for database auth. | ||
type Database struct { | ||
keys []common.JWTKey | ||
store common.Datastore | ||
keysMu sync.RWMutex | ||
} | ||
|
||
// NewDatabase init new database interface. | ||
func NewDatabase(store common.Datastore) (*Database, error) { | ||
db := &Database{ | ||
store: store, | ||
} | ||
keys, err := db.store.ListJWTKeys(context.Background()) | ||
if err != nil { | ||
return nil, fmt.Errorf("error ListKeys: %w", err) | ||
} | ||
if len(keys) > 0 { | ||
db.keysMu.Lock() | ||
defer db.keysMu.Unlock() | ||
db.keys = keys | ||
log.Info().Strs("keys", getKIDs(db.keys)).Msg("JWT keys loaded from database") | ||
return db, nil | ||
} | ||
if err := db.RotateKeys(); err != nil { | ||
return nil, err | ||
} | ||
return db, nil | ||
} | ||
|
||
func getKIDs(keys []common.JWTKey) []string { | ||
ids := make([]string, 0, len(keys)) | ||
for _, k := range keys { | ||
ids = append(ids, k.KID) | ||
} | ||
return ids | ||
} | ||
|
||
// RotateKeys rotates the jwt secrets. | ||
func (db *Database) RotateKeys() error { | ||
db.keysMu.Lock() | ||
defer db.keysMu.Unlock() | ||
start := time.Now() | ||
keys, err := common.GenerateNewKeyPair() | ||
if err != nil { | ||
return fmt.Errorf("error GenerateNewKeyPair: %w", err) | ||
} | ||
|
||
newest, err := db.store.AddJWTKey(context.Background(), *keys) | ||
if err != nil { | ||
return fmt.Errorf("error AddKeys: %w", err) | ||
} | ||
|
||
err = db.store.RotateJWTKeys(context.Background(), newest.ID) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
newKeys, err := db.refreshKeys(false) | ||
if err != nil { | ||
return err | ||
} | ||
db.keys = newKeys | ||
log.Info(). | ||
Strs("keys", getKIDs(db.keys)). | ||
Str("duration", time.Since(start).String()). | ||
Msg("JWT RotateKeys called") | ||
return nil | ||
} | ||
|
||
func (db *Database) refreshKeys(reload bool) ([]common.JWTKey, error) { | ||
keys, err := db.store.ListJWTKeys(context.Background()) | ||
if err != nil { | ||
return keys, fmt.Errorf("error ListKeys: %w", err) | ||
} | ||
if reload { | ||
db.keys = keys | ||
log.Info(). | ||
Strs("keys", getKIDs(db.keys)). | ||
Msg("JWT RefreshKeys called") | ||
} | ||
return keys, nil | ||
} | ||
|
||
// RefreshKeys refresh the keys from database. | ||
func (db *Database) RefreshKeys(reload bool) ([]common.JWTKey, error) { | ||
db.keysMu.Lock() | ||
defer db.keysMu.Unlock() | ||
return db.refreshKeys(reload) | ||
} | ||
|
||
// GetKeys fetch all keys from cache. | ||
func (db *Database) GetKeys() []common.JWTKey { | ||
db.keysMu.RLock() | ||
defer db.keysMu.RUnlock() | ||
data := make([]common.JWTKey, len(db.keys)) | ||
copy(data, db.keys) | ||
return data | ||
} | ||
|
||
// GetCurrentKey fetch latest key from cache, it should have privatekey. | ||
func (db *Database) GetCurrentKey() common.JWTKey { | ||
db.keysMu.RLock() | ||
defer db.keysMu.RUnlock() | ||
return db.keys[0] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package database | ||
|
||
import ( | ||
"context" | ||
"math/rand" | ||
"testing" | ||
"time" | ||
|
||
"github.com/elisasre/go-common" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type DB struct { | ||
keys []common.JWTKey | ||
} | ||
|
||
func (store *DB) AddJWTKey(c context.Context, payload common.JWTKey) (*common.JWTKey, error) { | ||
id := rand.Intn(1000000) | ||
payload.Model.ID = uint(id) | ||
payload.Model.CreatedAt = time.Now() | ||
payload.Model.UpdatedAt = time.Now() | ||
store.keys = append(store.keys, payload) | ||
return &payload, nil | ||
} | ||
|
||
func (store *DB) ListJWTKeys(c context.Context) ([]common.JWTKey, error) { | ||
return store.keys, nil | ||
} | ||
|
||
func (store *DB) RotateJWTKeys(c context.Context, idx uint) error { | ||
out := []common.JWTKey{} | ||
for _, key := range store.keys { | ||
if key.ID != idx { | ||
key.PrivateKey = nil | ||
key.PrivateKeyAsBytes = nil | ||
} | ||
out = append(out, key) | ||
} | ||
// keep 3 latest ones | ||
if len(out) > 3 { | ||
store.keys = []common.JWTKey{out[len(out)-3], out[len(out)-2], out[len(out)-1]} | ||
} else { | ||
store.keys = out | ||
} | ||
return nil | ||
} | ||
|
||
func TestRotateKeys(t *testing.T) { | ||
store := &DB{} | ||
db, err := NewDatabase(store) | ||
require.NoError(t, err) | ||
require.Equal(t, 1, len(db.keys)) | ||
|
||
db2, err := NewDatabase(store) | ||
require.NoError(t, err) | ||
require.Equal(t, 1, len(db2.keys)) | ||
require.Equal(t, db.keys, db2.keys) | ||
|
||
err = db.RotateKeys() | ||
require.NoError(t, err) | ||
require.Equal(t, 2, len(db.keys)) | ||
err = db.RotateKeys() | ||
require.NoError(t, err) | ||
require.Equal(t, 3, len(db.keys)) | ||
err = db.RotateKeys() | ||
require.NoError(t, err) | ||
require.Equal(t, 3, len(db.keys)) | ||
|
||
ids := make([]string, 0, len(db.keys)) | ||
for _, key := range db.keys { | ||
ids = append(ids, key.KID) | ||
} | ||
// rotate all keys | ||
for i := 0; i < 5; i++ { | ||
err = db.RotateKeys() | ||
require.NoError(t, err) | ||
} | ||
|
||
// check that mem.keys does not exist in ids | ||
for _, key := range db.keys { | ||
require.NotContains(t, ids, key.KID) | ||
} | ||
require.Equal(t, 3, len(db.keys)) | ||
} | ||
|
||
func TestGetAndRefresh(t *testing.T) { | ||
store := &DB{} | ||
db, err := NewDatabase(store) | ||
require.NoError(t, err) | ||
require.Equal(t, 1, len(db.keys)) | ||
require.Equal(t, db.keys[0], db.GetCurrentKey()) | ||
require.Equal(t, db.keys, db.GetKeys()) | ||
data, err := db.RefreshKeys(true) | ||
require.NoError(t, err) | ||
require.Equal(t, db.keys, data) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package memory | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
|
||
"github.com/elisasre/go-common" | ||
|
||
"github.com/rs/zerolog/log" | ||
) | ||
|
||
// Memory is an implementation of Interface for memory auth. | ||
type Memory struct { | ||
keys []common.JWTKey | ||
keysMu sync.RWMutex | ||
} | ||
|
||
// NewMemory init new memory interface. | ||
// Memory is used mainly for testing do NOT use in production. | ||
func NewMemory() (*Memory, error) { | ||
m := &Memory{} | ||
err := m.RotateKeys() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return m, nil | ||
} | ||
|
||
// RotateKeys rotates the jwt secrets. | ||
func (m *Memory) RotateKeys() error { | ||
m.keysMu.Lock() | ||
defer m.keysMu.Unlock() | ||
start := time.Now() | ||
keys, err := common.GenerateNewKeyPair() | ||
if err != nil { | ||
return err | ||
} | ||
// private key is needed only in newest which are used to generate new tokens | ||
for i := range m.keys { | ||
m.keys[i].PrivateKey = nil | ||
m.keys[i].PrivateKeyAsBytes = nil | ||
} | ||
m.keys = append([]common.JWTKey{*keys}, m.keys...) | ||
|
||
// keep 3 latest public keys | ||
if len(m.keys) > 3 { | ||
m.keys = m.keys[0:3] | ||
} | ||
log.Info(). | ||
Str("duration", time.Since(start).String()). | ||
Msg("rotate keys") | ||
return nil | ||
} | ||
|
||
// GetKeys fetch all keys from cache. | ||
func (m *Memory) GetKeys() []common.JWTKey { | ||
m.keysMu.RLock() | ||
defer m.keysMu.RUnlock() | ||
data := make([]common.JWTKey, len(m.keys)) | ||
copy(data, m.keys) | ||
return data | ||
} | ||
|
||
// GetCurrentKey fetch latest key from cache, it should have privatekey. | ||
func (m *Memory) GetCurrentKey() common.JWTKey { | ||
m.keysMu.RLock() | ||
defer m.keysMu.RUnlock() | ||
return m.keys[0] | ||
} | ||
|
||
// RefreshKeys refresh the keys from database. | ||
func (m *Memory) RefreshKeys(reload bool) ([]common.JWTKey, error) { | ||
return m.keys, nil | ||
} |
Oops, something went wrong.