Skip to content

Commit

Permalink
Ring-fenced concurrency safety for cassette and pcb
Browse files Browse the repository at this point in the history
Also, updated tests and added a concurrency safety test
  • Loading branch information
seborama committed Mar 3, 2019
1 parent 5cbc6ef commit 4627395
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 40 deletions.
20 changes: 13 additions & 7 deletions cassette.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,14 @@ func (k7 *cassette) addTrack(track *track) {
k7.Tracks = append(k7.Tracks, *track)
}

// Track retrieves the requested track number.
// '0' is the first track.
func (k7 *cassette) Track(trackNumber int32) track {
k7.trackSliceMutex.RLock()
defer k7.trackSliceMutex.RUnlock()
return k7.Tracks[trackNumber]
}

// Stats returns the cassette's Stats.
func (k7 *cassette) Stats() Stats {
stats := Stats{}
Expand All @@ -328,14 +336,12 @@ func (k7 *cassette) Stats() Stats {
func (k7 *cassette) tracksPlayed() int32 {
replayed := int32(0)

{
k7.trackSliceMutex.RLock()
defer k7.trackSliceMutex.RUnlock()
k7.trackSliceMutex.RLock()
defer k7.trackSliceMutex.RUnlock()

for _, t := range k7.Tracks {
if t.replayed {
replayed++
}
for _, t := range k7.Tracks {
if t.replayed {
replayed++
}
}

Expand Down
189 changes: 157 additions & 32 deletions govcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ import (
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"testing"

"net/http/httptest"
"strconv"
"testing"
"time"

"github.com/seborama/govcr"
)

func TestPlaybackOrder(t *testing.T) {
cassetteName := "TestPlaybackOrder"
clientNum := 1
clientNum := int8(1)

// create a test server
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -38,13 +40,17 @@ func TestPlaybackOrder(t *testing.T) {

// check outcome of the request
expectedBody := fmt.Sprintf("Hello, client %d", i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 0, i, 0)
if err := validateStats(vcr.Stats(), 0, i, 0); err != nil {
t.Fatal(err.Error())
}
}

fmt.Println("Phase 2 - Playback =====================================")
Expand All @@ -60,19 +66,23 @@ func TestPlaybackOrder(t *testing.T) {

// check outcome of the request
expectedBody := fmt.Sprintf("Hello, client %d", i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 10, 0, i)
if err := validateStats(vcr.Stats(), 10, 0, i); err != nil {
t.Fatal(err.Error())
}
}
}

func TestNonUtf8EncodableBinaryBody(t *testing.T) {
cassetteName := "TestNonUtf8EncodableBinaryBody"
clientNum := int32(1)
clientNum := int8(1)

// create a test server
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -97,18 +107,22 @@ func TestNonUtf8EncodableBinaryBody(t *testing.T) {
client := vcr.Client

// run requests
for i := int32(1); i <= 10; i++ {
for i := int8(1); i <= 10; i++ {
resp, _ := client.Get(ts.URL)

// check outcome of the request
expectedBody := generateBinaryBody(i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 0, i, 0)
if err := validateStats(vcr.Stats(), 0, int32(i), 0); err != nil {
t.Fatal(err.Error())
}
}

fmt.Println("Phase 2 - Playback =====================================")
Expand All @@ -123,20 +137,24 @@ func TestNonUtf8EncodableBinaryBody(t *testing.T) {
resp, _ := client.Get(ts.URL)

// check outcome of the request
expectedBody := generateBinaryBody(i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
expectedBody := generateBinaryBody(int8(i))
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 10, 0, i)
if err := validateStats(vcr.Stats(), 10, 0, i); err != nil {
t.Fatal(err.Error())
}
}
}

func TestLongPlay(t *testing.T) {
cassetteName := t.Name() + ".gz"
clientNum := 1
clientNum := int8(1)

// create a test server
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -159,13 +177,17 @@ func TestLongPlay(t *testing.T) {

// check outcome of the request
expectedBody := fmt.Sprintf("Hello, client %d", i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 0, i, 0)
if err := validateStats(vcr.Stats(), 0, i, 0); err != nil {
t.Fatal(err.Error())
}
}

fmt.Println("Phase 2 - Playback =====================================")
Expand All @@ -181,13 +203,112 @@ func TestLongPlay(t *testing.T) {

// check outcome of the request
expectedBody := fmt.Sprintf("Hello, client %d", i)
checkResponseForTestPlaybackOrder(t, resp, expectedBody)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatal(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}

checkStats(t, vcr.Stats(), 10, 0, i)
if err := validateStats(vcr.Stats(), 10, 0, i); err != nil {
t.Fatal(err.Error())
}
}
}

func TestConcurrencySafety(t *testing.T) {
cassetteName := "TestConcurrencySafety"
threadMax := int8(50)

// create a test server
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * time.Duration(rand.Intn(50)))

clientNum, _ := strconv.ParseInt(r.URL.Query().Get("num"), 0, 8)

data := generateBinaryBody(int8(clientNum))
written, err := w.Write(data)
if written != len(data) {
t.Fatalf("** Only %d bytes out of %d were written", written, len(data))
}
if err != nil {
t.Fatalf("err from w.Write(): Expected nil, got %s", err)
}
}))

fmt.Println("Phase 1 ================================================")

if err := govcr.DeleteCassette(cassetteName, ""); err != nil {
t.Fatalf("err from govcr.DeleteCassette(): Expected nil, got %s", err)
}

vcr := createVCR(cassetteName, false)
client := vcr.Client

t.Run("main - phase 1", func(t *testing.T) {
// run requests
for i := int8(1); i <= threadMax; i++ {
func(i1 int8) {
t.Run(fmt.Sprintf("i=%d", i), func(t *testing.T) {
t.Parallel()

func() {
resp, _ := client.Get(fmt.Sprintf("%s?num=%d", ts.URL, i1))

// check outcome of the request
expectedBody := generateBinaryBody(i1)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatalf(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}
}()
})
}(i)
}
})

if err := validateStats(vcr.Stats(), 0, int32(threadMax), 0); err != nil {
t.Fatal(err.Error())
}

fmt.Println("Phase 2 - Playback =====================================")

// re-run request and expect play back from vcr
vcr = createVCR(cassetteName, false)
client = vcr.Client

// run requests
t.Run("main - phase 1", func(t *testing.T) {
// run requests
for i := int8(1); i <= threadMax; i++ {
func(i1 int8) {
t.Run(fmt.Sprintf("i=%d", i), func(t *testing.T) {
t.Parallel()

func() {
resp, _ := client.Get(fmt.Sprintf("%s?num=%d", ts.URL, i1))

// check outcome of the request
expectedBody := generateBinaryBody(i1)
if err := validateResponseForTestPlaybackOrder(resp, expectedBody); err != nil {
t.Fatalf(err.Error())
}

if !govcr.CassetteExistsAndValid(cassetteName, "") {
t.Fatalf("CassetteExists: expected true, got false")
}
}()
})
}(i)
}
})

if err := validateStats(vcr.Stats(), int32(threadMax), 0, int32(threadMax)); err != nil {
t.Fatal(err.Error())
}
}

Expand All @@ -206,18 +327,18 @@ func createVCR(cassetteName string, lp bool) *govcr.VCRControlPanel {
})
}

func checkResponseForTestPlaybackOrder(t *testing.T, resp *http.Response, expectedBody interface{}) {
func validateResponseForTestPlaybackOrder(resp *http.Response, expectedBody interface{}) error {
if resp.StatusCode != http.StatusOK {
t.Fatalf("resp.StatusCode: Expected %d, got %d", http.StatusOK, resp.StatusCode)
return fmt.Errorf("resp.StatusCode: Expected %d, got %d", http.StatusOK, resp.StatusCode)
}

if resp.Body == nil {
t.Fatalf("resp.Body: Expected non-nil, got nil")
return fmt.Errorf("resp.Body: Expected non-nil, got nil")
}

bodyData, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("err from ioutil.ReadAll(): Expected nil, got %s", err)
return fmt.Errorf("err from ioutil.ReadAll(): Expected nil, got %s", err)
}
resp.Body.Close()

Expand All @@ -227,40 +348,44 @@ func checkResponseForTestPlaybackOrder(t *testing.T, resp *http.Response, expect
var ok bool
expectedBodyBytes, ok = expectedBody.([]byte)
if !ok {
t.Fatalf("expectedBody: cannot assert to type '[]byte'")
return fmt.Errorf("expectedBody: cannot assert to type '[]byte'")
}

case string:
expectedBodyString, ok := expectedBody.(string)
if !ok {
t.Fatalf("expectedBody: cannot assert to type 'string'")
return fmt.Errorf("expectedBody: cannot assert to type 'string'")
}
expectedBodyBytes = []byte(expectedBodyString)

default:
t.Fatalf("Unexpected type for 'expectedBody' variable")
return fmt.Errorf("Unexpected type for 'expectedBody' variable")
}

if !bytes.Equal(bodyData, expectedBodyBytes) {
t.Fatalf("Body: expected '%s', got '%s'", expectedBody, bodyData)
return fmt.Errorf("Body: expected '%v', got '%v'", expectedBody, bodyData)
}

return nil
}

func checkStats(t *testing.T, actualStats govcr.Stats, expectedTracksLoaded, expectedTracksRecorded, expectedTrackPlayed int32) {
func validateStats(actualStats govcr.Stats, expectedTracksLoaded, expectedTracksRecorded, expectedTrackPlayed int32) error {
if actualStats.TracksLoaded != expectedTracksLoaded {
t.Fatalf("Expected %d track loaded, got %d", expectedTracksLoaded, actualStats.TracksLoaded)
return fmt.Errorf("Expected %d track loaded, got %d", expectedTracksLoaded, actualStats.TracksLoaded)
}

if actualStats.TracksRecorded != expectedTracksRecorded {
t.Fatalf("Expected %d track recorded, got %d", expectedTracksRecorded, actualStats.TracksRecorded)
return fmt.Errorf("Expected %d track recorded, got %d", expectedTracksRecorded, actualStats.TracksRecorded)
}

if actualStats.TracksPlayed != expectedTrackPlayed {
t.Fatalf("Expected %d track played, got %d", expectedTrackPlayed, actualStats.TracksPlayed)
return fmt.Errorf("Expected %d track played, got %d", expectedTrackPlayed, actualStats.TracksPlayed)
}

return nil
}

func generateBinaryBody(sequence int32) []byte {
func generateBinaryBody(sequence int8) []byte {
data := make([]byte, 256, 257)
for i := range data {
data[i] = byte(i)
Expand Down
2 changes: 1 addition & 1 deletion pcb.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (pcbr *pcb) seekTrack(cassette *cassette, req *http.Request) (*http.Respons

// Matches checks whether the track is a match for the supplied request.
func (pcbr *pcb) trackMatches(cassette *cassette, trackNumber int32, request Request) bool {
track := cassette.Tracks[trackNumber]
track := cassette.Track(trackNumber)

// apply filter function to track header / body
filteredTrackRequest := pcbr.RequestFilter(track.Request.Request())
Expand Down

0 comments on commit 4627395

Please sign in to comment.