Skip to content

Commit

Permalink
Merge pull request #50 from elisasre/auth
Browse files Browse the repository at this point in the history
add auth store
  • Loading branch information
zetaab authored May 10, 2023
2 parents e156500 + c3bff32 commit 0496e26
Show file tree
Hide file tree
Showing 14 changed files with 458 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:

- uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: '1.20'

- name: Ensure
run: |
Expand All @@ -31,5 +31,5 @@ jobs:
- name: Lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.50
version: v1.51.2
skip-cache: true
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: '1.20'

- name: Ensure
run: |
Expand All @@ -33,7 +33,7 @@ jobs:
- name: Lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.50
version: v1.51.2
skip-cache: true

automerge:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:

- uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: '1.20'

- name: publish package
run: |
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: vendor
fail_fast: true
repos:
- repo: https://github.com/golangci/golangci-lint
rev: v1.51.1
rev: v1.51.2
hooks:
- id: golangci-lint
args: [ --fix ]
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@ clean:

ensure:
go mod tidy
go mod vendor

build:
rm -f bin/$(OPERATOR_NAME)
go build -mod vendor -v -o bin/$(OPERATOR_NAME) .
go build -v -o bin/$(OPERATOR_NAME) .

golint: .git/hooks/pre-commit
pre-commit run --all-files

test:
go test -failfast -mod vendor ./*.go -v -covermode atomic -coverprofile=gotest-coverage.out $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1)
go test -race -covermode atomic -coverprofile=gotest-coverage.out ./... $(GOTEST_REPORT_FORMAT) > gotest-report.out && cat gotest-report.out || (cat gotest-report.out; exit 1)
git diff --exit-code go.mod go.sum


Expand Down
31 changes: 31 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package auth

import (
"fmt"

"github.com/elisasre/go-common"
database "github.com/elisasre/go-common/auth/db"
"github.com/elisasre/go-common/auth/memory"

"github.com/rs/zerolog/log"
)

// AuthInterface will contain interface to interact with different auth providers.
type AuthInterface interface {
GetKeys() []common.JWTKey
GetCurrentKey() common.JWTKey
RotateKeys() error
RefreshKeys(bool) ([]common.JWTKey, error)
}

func AuthProvider(mode string, store common.Datastore) (AuthInterface, error) {
log.Info().Str("mode", mode).Msg("Using AuthProvider")
switch mode {
case "memory":
return memory.NewMemory()
case "database":
return database.NewDatabase(store)
default:
return nil, fmt.Errorf("unknown auth mode '%s'", mode)
}
}
118 changes: 118 additions & 0 deletions auth/db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package database

import (
"context"
"fmt"
"sync"
"time"

"github.com/elisasre/go-common"

"github.com/rs/zerolog/log"
)

// Database is an implementation of Interface for database auth.
type Database struct {
keys []common.JWTKey
store common.Datastore
keysMu sync.RWMutex
}

// NewDatabase init new database interface.
func NewDatabase(store common.Datastore) (*Database, error) {
db := &Database{
store: store,
}
keys, err := db.store.ListJWTKeys(context.Background())
if err != nil {
return nil, fmt.Errorf("error ListKeys: %w", err)
}
if len(keys) > 0 {
db.keysMu.Lock()
defer db.keysMu.Unlock()
db.keys = keys
log.Info().Strs("keys", getKIDs(db.keys)).Msg("JWT keys loaded from database")
return db, nil
}
if err := db.RotateKeys(); err != nil {
return nil, err
}
return db, nil
}

func getKIDs(keys []common.JWTKey) []string {
ids := make([]string, 0, len(keys))
for _, k := range keys {
ids = append(ids, k.KID)
}
return ids
}

// RotateKeys rotates the jwt secrets.
func (db *Database) RotateKeys() error {
db.keysMu.Lock()
defer db.keysMu.Unlock()
start := time.Now()
keys, err := common.GenerateNewKeyPair()
if err != nil {
return fmt.Errorf("error GenerateNewKeyPair: %w", err)
}

newest, err := db.store.AddJWTKey(context.Background(), *keys)
if err != nil {
return fmt.Errorf("error AddKeys: %w", err)
}

err = db.store.RotateJWTKeys(context.Background(), newest.ID)
if err != nil {
return err
}

newKeys, err := db.refreshKeys(false)
if err != nil {
return err
}
db.keys = newKeys
log.Info().
Strs("keys", getKIDs(db.keys)).
Str("duration", time.Since(start).String()).
Msg("JWT RotateKeys called")
return nil
}

func (db *Database) refreshKeys(reload bool) ([]common.JWTKey, error) {
keys, err := db.store.ListJWTKeys(context.Background())
if err != nil {
return keys, fmt.Errorf("error ListKeys: %w", err)
}
if reload {
db.keys = keys
log.Info().
Strs("keys", getKIDs(db.keys)).
Msg("JWT RefreshKeys called")
}
return keys, nil
}

// RefreshKeys refresh the keys from database.
func (db *Database) RefreshKeys(reload bool) ([]common.JWTKey, error) {
db.keysMu.Lock()
defer db.keysMu.Unlock()
return db.refreshKeys(reload)
}

// GetKeys fetch all keys from cache.
func (db *Database) GetKeys() []common.JWTKey {
db.keysMu.RLock()
defer db.keysMu.RUnlock()
data := make([]common.JWTKey, len(db.keys))
copy(data, db.keys)
return data
}

// GetCurrentKey fetch latest key from cache, it should have privatekey.
func (db *Database) GetCurrentKey() common.JWTKey {
db.keysMu.RLock()
defer db.keysMu.RUnlock()
return db.keys[0]
}
96 changes: 96 additions & 0 deletions auth/db/db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package database

import (
"context"
"math/rand"
"testing"
"time"

"github.com/elisasre/go-common"
"github.com/stretchr/testify/require"
)

type DB struct {
keys []common.JWTKey
}

func (store *DB) AddJWTKey(c context.Context, payload common.JWTKey) (*common.JWTKey, error) {
id := rand.Intn(1000000)
payload.Model.ID = uint(id)
payload.Model.CreatedAt = time.Now()
payload.Model.UpdatedAt = time.Now()
store.keys = append(store.keys, payload)
return &payload, nil
}

func (store *DB) ListJWTKeys(c context.Context) ([]common.JWTKey, error) {
return store.keys, nil
}

func (store *DB) RotateJWTKeys(c context.Context, idx uint) error {
out := []common.JWTKey{}
for _, key := range store.keys {
if key.ID != idx {
key.PrivateKey = nil
key.PrivateKeyAsBytes = nil
}
out = append(out, key)
}
// keep 3 latest ones
if len(out) > 3 {
store.keys = []common.JWTKey{out[len(out)-3], out[len(out)-2], out[len(out)-1]}
} else {
store.keys = out
}
return nil
}

func TestRotateKeys(t *testing.T) {
store := &DB{}
db, err := NewDatabase(store)
require.NoError(t, err)
require.Equal(t, 1, len(db.keys))

db2, err := NewDatabase(store)
require.NoError(t, err)
require.Equal(t, 1, len(db2.keys))
require.Equal(t, db.keys, db2.keys)

err = db.RotateKeys()
require.NoError(t, err)
require.Equal(t, 2, len(db.keys))
err = db.RotateKeys()
require.NoError(t, err)
require.Equal(t, 3, len(db.keys))
err = db.RotateKeys()
require.NoError(t, err)
require.Equal(t, 3, len(db.keys))

ids := make([]string, 0, len(db.keys))
for _, key := range db.keys {
ids = append(ids, key.KID)
}
// rotate all keys
for i := 0; i < 5; i++ {
err = db.RotateKeys()
require.NoError(t, err)
}

// check that mem.keys does not exist in ids
for _, key := range db.keys {
require.NotContains(t, ids, key.KID)
}
require.Equal(t, 3, len(db.keys))
}

func TestGetAndRefresh(t *testing.T) {
store := &DB{}
db, err := NewDatabase(store)
require.NoError(t, err)
require.Equal(t, 1, len(db.keys))
require.Equal(t, db.keys[0], db.GetCurrentKey())
require.Equal(t, db.keys, db.GetKeys())
data, err := db.RefreshKeys(true)
require.NoError(t, err)
require.Equal(t, db.keys, data)
}
74 changes: 74 additions & 0 deletions auth/memory/memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package memory

import (
"sync"
"time"

"github.com/elisasre/go-common"

"github.com/rs/zerolog/log"
)

// Memory is an implementation of Interface for memory auth.
type Memory struct {
keys []common.JWTKey
keysMu sync.RWMutex
}

// NewMemory init new memory interface.
// Memory is used mainly for testing do NOT use in production.
func NewMemory() (*Memory, error) {
m := &Memory{}
err := m.RotateKeys()
if err != nil {
return nil, err
}
return m, nil
}

// RotateKeys rotates the jwt secrets.
func (m *Memory) RotateKeys() error {
m.keysMu.Lock()
defer m.keysMu.Unlock()
start := time.Now()
keys, err := common.GenerateNewKeyPair()
if err != nil {
return err
}
// private key is needed only in newest which are used to generate new tokens
for i := range m.keys {
m.keys[i].PrivateKey = nil
m.keys[i].PrivateKeyAsBytes = nil
}
m.keys = append([]common.JWTKey{*keys}, m.keys...)

// keep 3 latest public keys
if len(m.keys) > 3 {
m.keys = m.keys[0:3]
}
log.Info().
Str("duration", time.Since(start).String()).
Msg("rotate keys")
return nil
}

// GetKeys fetch all keys from cache.
func (m *Memory) GetKeys() []common.JWTKey {
m.keysMu.RLock()
defer m.keysMu.RUnlock()
data := make([]common.JWTKey, len(m.keys))
copy(data, m.keys)
return data
}

// GetCurrentKey fetch latest key from cache, it should have privatekey.
func (m *Memory) GetCurrentKey() common.JWTKey {
m.keysMu.RLock()
defer m.keysMu.RUnlock()
return m.keys[0]
}

// RefreshKeys refresh the keys from database.
func (m *Memory) RefreshKeys(reload bool) ([]common.JWTKey, error) {
return m.keys, nil
}
Loading

0 comments on commit 0496e26

Please sign in to comment.