Skip to content

Commit

Permalink
Merge pull request #22 from seborama/Fixes
Browse files Browse the repository at this point in the history
[Fixes] Fixes & improvements
  • Loading branch information
seborama authored Sep 10, 2016
2 parents 073b5fd + fce5997 commit ab09844
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions govcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
"net/http"
"os"
"strings"
)

// VCRControlPanel holds the parts of a VCR that can be interacted with.
Expand Down Expand Up @@ -92,11 +93,10 @@ func (pcbr *pcb) trackMatches(cassette *cassette, trackNumber int, req *http.Req

// headerResembles compares HTTP headers for equivalence.
func (pcbr *pcb) headerResembles(header1 http.Header, header2 http.Header) bool {
for k, v1 := range header1 {
for _, v2 := range v1 {
if header2.Get(k) != v2 && !pcbr.ExcludeHeaderFunc(k) {
return false
}
for k := range header1 {
// TODO: a given header may have several values (and in any order)
if GetFirstValue(header1, k) != GetFirstValue(header2, k) && !pcbr.ExcludeHeaderFunc(k) {
return false
}
}

Expand All @@ -110,8 +110,8 @@ func (pcbr *pcb) bodyResembles(body1 string, body2 string) bool {
return *pcbr.RequestBodyFilterFunc(body1) == *pcbr.RequestBodyFilterFunc(body2)
}

func (pcbr *pcb) filterHeader(resp *http.Response) *http.Response {
resp.Header = *pcbr.ResponseHeaderFilterFunc(resp.Header)
func (pcbr *pcb) filterHeader(resp *http.Response, reqHdr http.Header) *http.Response {
resp.Header = *pcbr.ResponseHeaderFilterFunc(resp.Header, reqHdr)
return resp
}

Expand All @@ -127,6 +127,22 @@ func (pcbr *pcb) filterBody(resp *http.Response) *http.Response {
return resp
}

// GetFirstValue is a utility function that extracts the first value of a header key.
// The reason for this function is that some servers require case sensitive headers which
// prevent the use of http.Header.Get() as it expects header keys to be canonicalized.
func GetFirstValue(hdr http.Header, key string) string {
for k, val := range hdr {
if strings.ToLower(k) == strings.ToLower(key) {
if len(val) > 0 {
return val[0]
}
return ""
}
}

return ""
}

// NewVCR creates a new VCR and loads a cassette.
// A RoundTripper can be provided when a custom Transport is needed (for example to provide
// certificates, etc)
Expand Down Expand Up @@ -155,7 +171,7 @@ func NewVCR(cassetteName string, vcrConfig *VCRConfig) *VCRControlPanel {
// use a default set of FilterFunc's
if vcrConfig.ExcludeHeaderFunc == nil {
vcrConfig.ExcludeHeaderFunc = func(key string) bool {
return true
return false
}
}

Expand All @@ -166,8 +182,8 @@ func NewVCR(cassetteName string, vcrConfig *VCRConfig) *VCRControlPanel {
}

if vcrConfig.ResponseHeaderFilterFunc == nil {
vcrConfig.ResponseHeaderFilterFunc = func(header http.Header) *http.Header {
return &header
vcrConfig.ResponseHeaderFilterFunc = func(respHdr http.Header, reqHdr http.Header) *http.Header {
return &respHdr
}
}

Expand Down Expand Up @@ -239,7 +255,7 @@ type BodyFilterFunc func(string) *string
//
// It is important to note that this differs from ExcludeHeaderFunc in that the former does not
// modify the header (it only returns a bool) whereas this function can be used to modify the header.
type HeaderFilterFunc func(header http.Header) *http.Header
type HeaderFilterFunc func(http.Header, http.Header) *http.Header

// vcrTransport is the heart of VCR. It provides
// an http.RoundTripper that wraps over the default
Expand Down Expand Up @@ -272,7 +288,8 @@ func (t *vcrTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// attempt to use a track from the cassette that matches
// the request if one exists.
if trackNumber := t.PCB.seekTrack(t.Cassette, copiedReq); trackNumber != trackNotFound {
resp = t.PCB.filterHeader(t.PCB.filterBody(t.Cassette.replayResponse(trackNumber, copiedReq)))
// only the played back response is filtered. Never the live response!
resp = t.PCB.filterHeader(t.PCB.filterBody(t.Cassette.replayResponse(trackNumber, copiedReq)), copiedReq.Header)
requestMatched = true
}

Expand Down

0 comments on commit ab09844

Please sign in to comment.