From fce5997fda8847b04689d7d5684d9a4792bd654c Mon Sep 17 00:00:00 2001 From: seborama Date: Sat, 10 Sep 2016 02:09:21 +0100 Subject: [PATCH] [Fixes] Fixes & improvements --- govcr.go | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/govcr.go b/govcr.go index 430b663..621613d 100644 --- a/govcr.go +++ b/govcr.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "os" + "strings" ) // VCRControlPanel holds the parts of a VCR that can be interacted with. @@ -92,11 +93,10 @@ func (pcbr *pcb) trackMatches(cassette *cassette, trackNumber int, req *http.Req // headerResembles compares HTTP headers for equivalence. func (pcbr *pcb) headerResembles(header1 http.Header, header2 http.Header) bool { - for k, v1 := range header1 { - for _, v2 := range v1 { - if header2.Get(k) != v2 && !pcbr.ExcludeHeaderFunc(k) { - return false - } + for k := range header1 { + // TODO: a given header may have several values (and in any order) + if GetFirstValue(header1, k) != GetFirstValue(header2, k) && !pcbr.ExcludeHeaderFunc(k) { + return false } } @@ -110,8 +110,8 @@ func (pcbr *pcb) bodyResembles(body1 string, body2 string) bool { return *pcbr.RequestBodyFilterFunc(body1) == *pcbr.RequestBodyFilterFunc(body2) } -func (pcbr *pcb) filterHeader(resp *http.Response) *http.Response { - resp.Header = *pcbr.ResponseHeaderFilterFunc(resp.Header) +func (pcbr *pcb) filterHeader(resp *http.Response, reqHdr http.Header) *http.Response { + resp.Header = *pcbr.ResponseHeaderFilterFunc(resp.Header, reqHdr) return resp } @@ -127,6 +127,22 @@ func (pcbr *pcb) filterBody(resp *http.Response) *http.Response { return resp } +// GetFirstValue is a utility function that extracts the first value of a header key. +// The reason for this function is that some servers require case sensitive headers which +// prevent the use of http.Header.Get() as it expects header keys to be canonicalized. +func GetFirstValue(hdr http.Header, key string) string { + for k, val := range hdr { + if strings.ToLower(k) == strings.ToLower(key) { + if len(val) > 0 { + return val[0] + } + return "" + } + } + + return "" +} + // NewVCR creates a new VCR and loads a cassette. // A RoundTripper can be provided when a custom Transport is needed (for example to provide // certificates, etc) @@ -155,7 +171,7 @@ func NewVCR(cassetteName string, vcrConfig *VCRConfig) *VCRControlPanel { // use a default set of FilterFunc's if vcrConfig.ExcludeHeaderFunc == nil { vcrConfig.ExcludeHeaderFunc = func(key string) bool { - return true + return false } } @@ -166,8 +182,8 @@ func NewVCR(cassetteName string, vcrConfig *VCRConfig) *VCRControlPanel { } if vcrConfig.ResponseHeaderFilterFunc == nil { - vcrConfig.ResponseHeaderFilterFunc = func(header http.Header) *http.Header { - return &header + vcrConfig.ResponseHeaderFilterFunc = func(respHdr http.Header, reqHdr http.Header) *http.Header { + return &respHdr } } @@ -239,7 +255,7 @@ type BodyFilterFunc func(string) *string // // It is important to note that this differs from ExcludeHeaderFunc in that the former does not // modify the header (it only returns a bool) whereas this function can be used to modify the header. -type HeaderFilterFunc func(header http.Header) *http.Header +type HeaderFilterFunc func(http.Header, http.Header) *http.Header // vcrTransport is the heart of VCR. It provides // an http.RoundTripper that wraps over the default @@ -272,7 +288,8 @@ func (t *vcrTransport) RoundTrip(req *http.Request) (*http.Response, error) { // attempt to use a track from the cassette that matches // the request if one exists. if trackNumber := t.PCB.seekTrack(t.Cassette, copiedReq); trackNumber != trackNotFound { - resp = t.PCB.filterHeader(t.PCB.filterBody(t.Cassette.replayResponse(trackNumber, copiedReq))) + // only the played back response is filtered. Never the live response! + resp = t.PCB.filterHeader(t.PCB.filterBody(t.Cassette.replayResponse(trackNumber, copiedReq)), copiedReq.Header) requestMatched = true }