Skip to content

Commit

Permalink
Enable switching a decrypted cassette to encryption mode (#78)
Browse files Browse the repository at this point in the history
Also, increase test coverage
  • Loading branch information
seborama authored Aug 15, 2022
1 parent 57704a7 commit 5a7addf
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 16 deletions.
17 changes: 11 additions & 6 deletions cassette/cassette.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (k7 *Cassette) IsLongPlay() bool {
return strings.HasSuffix(k7.name, ".gz")
}

func (k7 *Cassette) isEncrypted() bool {
func (k7 *Cassette) wantEncrypted() bool {
return k7.crypter != nil
}

Expand Down Expand Up @@ -191,7 +191,7 @@ func (k7 *Cassette) GunzipFilter(data []byte) ([]byte, error) {
// EncryptionFilter encrypts the cassette data if a cryptographer Crypter
// was supplied, otherwise data is left as is.
func (k7 *Cassette) EncryptionFilter(data []byte) ([]byte, error) {
if !k7.isEncrypted() {
if !k7.wantEncrypted() {
return data, nil
}

Expand Down Expand Up @@ -219,14 +219,19 @@ func (k7 *Cassette) EncryptionFilter(data []byte) ([]byte, error) {
func (k7 *Cassette) DecryptionFilter(data []byte) ([]byte, error) {
hasEncryptionMarker := bytes.HasPrefix(data, []byte(encryptedCassetteHeader))

if !k7.isEncrypted() {
if !k7.wantEncrypted() {
if hasEncryptionMarker {
return nil, cryptoerr.NewErrCrypto("cassette has encryption marker but no cryptographer was supplied")
}

return data, nil
}

if !hasEncryptionMarker {
// We're going off the chance that the cassette file is not encrypted yet but that from next save it should be.
return data, nil
}

return Decrypt(data, k7.crypter)
}

Expand Down Expand Up @@ -262,11 +267,11 @@ func (k7 *Cassette) readCassetteFile(cassetteName string) error {

gData, err := k7.GunzipFilter(dData)
if err != nil {
return err
return errors.WithStack(err)
}

// NOTE: Properties which are of type 'interface{} / any' are not handled very well
if err := json.Unmarshal(gData, k7); err != nil {
if err = json.Unmarshal(gData, k7); err != nil {
return errors.Wrap(err, "failed to interpret cassette data in file")
}

Expand Down Expand Up @@ -316,7 +321,7 @@ func LoadCassette(cassetteName string, opts ...Option) *Cassette {

err := k7.readCassetteFile(cassetteName)
if err != nil {
panic(fmt.Sprintf("unable to load corrupted cassette '%s': %v", cassetteName, err))
panic(fmt.Sprintf("unable to load corrupted cassette '%s': %+v", cassetteName, err))
}

// initial stats
Expand Down
62 changes: 60 additions & 2 deletions cassette/cassette_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func Test_cassette_Encryption(t *testing.T) {

_ = os.Remove(cassetteName)

// STEP 1: create encrypted cassette.
key := []byte("12345678901234567890123456789012")
c, err := encryption.NewAESCGM(key, nil)
require.NoError(t, err)
Expand All @@ -138,11 +139,13 @@ func Test_cassette_Encryption(t *testing.T) {
err = cassette.AddTrackToCassette(k7, trk)
require.NoError(t, err)

// STEP 2: ensure cassette loads.
var k8 *cassette.Cassette
require.NotPanics(t, func() {
k8 = cassette.LoadCassette(cassetteName, cassette.WithCassetteCrypter(c))
})

// STEP 3: perform high and low-level validation checks on cassette file.
data, err := os.ReadFile(cassetteName) // nolint:gosec
require.NoError(t, err)

Expand All @@ -157,8 +160,63 @@ func Test_cassette_Encryption(t *testing.T) {

require.Equal(t, k7.NumberOfTracks(), k8.NumberOfTracks())

for i := range k8.Tracks {
k8.Tracks[i].SetReplayed(true) // so to match k7
for i := range k7.Tracks {
k7.Tracks[i].SetReplayed(false) // so to match k8
}

require.Equal(t, k7.Tracks, k8.Tracks)
}

func Test_cassette_CanEncryptPlainCassette(t *testing.T) {
const cassetteName = "temp-fixtures/Test_cassette_CanEncryptPlainCassette"

_ = os.Remove(cassetteName)

// STEP 1a: create a non-encrypted cassette.
// This is not required for cassette encryption, this is for the purpose of confirming
// that a non-encrypted cassette will convert to an encrypted cassette seamlessly.
k7 := cassette.NewCassette(cassetteName)

trk := &track.Track{UUID: "trk-1"}

err := cassette.AddTrackToCassette(k7, trk)
require.NoError(t, err)

// STEP 1b: add track to cassette, this time encrypt the cassette.
key := []byte("12345678901234567890123456789012")
c, err := encryption.NewAESCGM(key, nil)
require.NoError(t, err)

k7 = cassette.LoadCassette(cassetteName, cassette.WithCassetteCrypter(c))

trk = &track.Track{UUID: "trk-2"}

err = cassette.AddTrackToCassette(k7, trk)
require.NoError(t, err)

// STEP 2: ensure cassette loads.
var k8 *cassette.Cassette
require.NotPanics(t, func() {
k8 = cassette.LoadCassette(cassetteName, cassette.WithCassetteCrypter(c))
})

// STEP 3: perform high and low-level validation checks on cassette file.
data, err := os.ReadFile(cassetteName) // nolint:gosec
require.NoError(t, err)

const encryptedCassetteHeader = "$ENC$"

require.True(t, bytes.HasPrefix(data, []byte(encryptedCassetteHeader)))

nonceLen := int(data[len(encryptedCassetteHeader)])
nonce := data[len(encryptedCassetteHeader)+1 : len(encryptedCassetteHeader)+1+nonceLen]

t.Logf("nonce: %x\n", nonce)

require.Equal(t, k7.NumberOfTracks(), k8.NumberOfTracks())

for i := range k7.Tracks {
k7.Tracks[i].SetReplayed(false) // so to match k8
}

require.Equal(t, k7.Tracks, k8.Tracks)
Expand Down
23 changes: 23 additions & 0 deletions cassette/track/http_wb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package track

import (
"net/url"
"testing"

"github.com/stretchr/testify/assert"
)

func Test_cloneURLValues(t *testing.T) {
unit := url.Values{
"one": {
"one.1", "one.2",
},
"two": {
"two.1", "",
},
"three": {},
}

got := cloneURLValues(unit)
assert.Equal(t, unit, got)
}
88 changes: 88 additions & 0 deletions controlpanel_wb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package govcr

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"

"github.com/seborama/govcr/v8/cassette/track"
)

func TestControlPanel_SetRecordingMutators(t *testing.T) {
unit := &ControlPanel{
client: &http.Client{
Transport: &vcrTransport{
pcb: &PrintedCircuitBoard{
trackRecordingMutators: track.Mutators{
track.TrackRequestAddHeaderValue("k", "v"),
track.TrackRequestDeleteHeaderKeys("k2"),
},
},
},
},
}

unit.SetRecordingMutators(track.TrackRequestDeleteHeaderKeys("k1"))

assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackRecordingMutators, 1)
assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackReplayingMutators, 0)
}

func TestControlPanel_AddRecordingMutators(t *testing.T) {
unit := &ControlPanel{
client: &http.Client{
Transport: &vcrTransport{
pcb: &PrintedCircuitBoard{
trackRecordingMutators: track.Mutators{
track.TrackRequestAddHeaderValue("k", "v"),
},
},
},
},
}

unit.AddRecordingMutators(track.TrackRequestDeleteHeaderKeys("k2"))

assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackRecordingMutators, 2)
assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackReplayingMutators, 0)
}

func TestControlPanel_SetReplayingMutators(t *testing.T) {
unit := &ControlPanel{
client: &http.Client{
Transport: &vcrTransport{
pcb: &PrintedCircuitBoard{
trackReplayingMutators: track.Mutators{
track.TrackRequestAddHeaderValue("k", "v"),
track.TrackRequestDeleteHeaderKeys("k2"),
},
},
},
},
}

unit.SetReplayingMutators(track.TrackRequestDeleteHeaderKeys("k1"))

assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackReplayingMutators, 1)
assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackRecordingMutators, 0)
}

func TestControlPanel_AddReplayingMutators(t *testing.T) {
unit := &ControlPanel{
client: &http.Client{
Transport: &vcrTransport{
pcb: &PrintedCircuitBoard{
trackReplayingMutators: track.Mutators{
track.TrackRequestAddHeaderValue("k", "v"),
},
},
},
},
}

unit.AddReplayingMutators(track.TrackRequestDeleteHeaderKeys("k2"))

assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackReplayingMutators, 2)
assert.Len(t, unit.client.Transport.(*vcrTransport).pcb.trackRecordingMutators, 0)
}
6 changes: 2 additions & 4 deletions govcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ func TestVCRControlPanel_LoadCassette_WhenOneIsAlreadyLoaded(t *testing.T) {

func TestVCRControlPanel_LoadCassette_InvalidCassette(t *testing.T) {
unit := govcr.NewVCR()
assert.PanicsWithValue(
assert.Panics(
t,
"unable to load corrupted cassette 'test-fixtures/bad.cassette.json': failed to interpret cassette data in file: invalid character 'T' looking for beginning of value",
func() {
_ = unit.LoadCassette("test-fixtures/bad.cassette.json")
})
Expand All @@ -70,9 +69,8 @@ func TestVCRControlPanel_LoadCassette_UnreadableCassette(t *testing.T) {
createUnreadableCassette(t, cassetteName)

unit := govcr.NewVCR()
assert.PanicsWithValue(
assert.Panics(
t,
"unable to load corrupted cassette '"+cassetteName+"': failed to read cassette data from file: open "+cassetteName+": permission denied",
func() {
_ = unit.LoadCassette(cassetteName)
})
Expand Down
9 changes: 5 additions & 4 deletions vcrsettings.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package govcr

import (
"fmt"
"net/http"
"os"

Expand Down Expand Up @@ -33,12 +34,12 @@ func WithCassetteCrypto(keyFile string) CassetteOption {
return func(cfg *CassetteConfig) {
key, err := os.ReadFile(keyFile)
if err != nil {
panic(err)
panic(fmt.Sprintf("%+v", err))
}

crypter, err := encryption.NewAESCGM(key, nil)
if err != nil {
panic(err)
panic(fmt.Sprintf("%+v", err))
}

cfg.Crypter = crypter
Expand All @@ -51,12 +52,12 @@ func WithCassetteCryptoCustomNonce(keyFile string, nonceGenerator encryption.Non
return func(cfg *CassetteConfig) {
key, err := os.ReadFile(keyFile)
if err != nil {
panic(err)
panic(fmt.Sprintf("%+v", err))
}

crypter, err := encryption.NewAESCGM(key, nonceGenerator)
if err != nil {
panic(err)
panic(fmt.Sprintf("%+v", err))
}

cfg.Crypter = crypter
Expand Down
28 changes: 28 additions & 0 deletions vcrsettings_wb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package govcr

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestWithLiveOnlyMode(t *testing.T) {
vcrSettings := &VCRSettings{}

WithLiveOnlyMode()(vcrSettings)
assert.Equal(t, HTTPModeLiveOnly, vcrSettings.httpMode)
}

func TestWithOfflineMode(t *testing.T) {
vcrSettings := &VCRSettings{}

WithOfflineMode()(vcrSettings)
assert.Equal(t, HTTPModeOffline, vcrSettings.httpMode)
}

func TestW(t *testing.T) {
vcrSettings := &VCRSettings{}

WithReadOnlyMode()(vcrSettings)
assert.True(t, vcrSettings.readOnly)
}

0 comments on commit 5a7addf

Please sign in to comment.