From 0b4c75e9bd3d812185d90335f41bcd78d228504f Mon Sep 17 00:00:00 2001 From: Martin Sucha Date: Sat, 24 Oct 2020 16:00:17 +0200 Subject: [PATCH] Add support for passing context When serving data from remote location, one might need to pass the request context to backend storage of entries, for example for distributed tracing to work. --- README.md | 13 +-- archive.go | 87 +++++++++---------- go.mod | 2 - go.sum | 2 - io.go | 123 ++++++++++++++++++++++++++ io_test.go | 244 ++++++++++++++++++++++++++++++++++++++++++++++++++++ struct.go | 3 + zip_test.go | 10 +-- 8 files changed, 424 insertions(+), 60 deletions(-) create mode 100644 io.go create mode 100644 io_test.go diff --git a/README.md b/README.md index 87ed993..48bf349 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,12 @@ zipserve Package zipserve implements serving virtual zip archives over HTTP, with support for range queries and resumable downloads. Zipserve keeps only the archive headers in memory (similar to archive/zip when streaming). -The actual file data is fetched on demand from user-provided ReaderAt, -so the file data can be fetched from a remote location. -Zipserve needs to know CRC32 of the uncompressed data, compressed and uncompressed size of files in advance, -which must be supplied by the user. +Zipserve fetches file data on demand from user-provided `io.ReaderAt` or `zipserve.ReaderAt`, +so the file data can be fetched from a remote location. +`zipserve.ReaderAt` supports passing request context to the backing store. + +The user has to provide CRC32 of the uncompressed data, compressed and uncompressed size of files in advance. +These can be computed for example during file uploads. Differences to archive/zip -------------------------- @@ -35,8 +37,7 @@ so there aren't many commits. 
I update the module when a new version of Go is re License ------- -Three clause BSD (same as Go) for files in this package (see [LICENSE](LICENSE)), -Apache 2.0 for readerutil package from go4.org which is used as a dependency. +Three clause BSD (same as Go), see [LICENSE](LICENSE). Alternatives ------------ diff --git a/archive.go b/archive.go index 6e246a6..b505699 100644 --- a/archive.go +++ b/archive.go @@ -17,12 +17,12 @@ package zipserve import ( "bytes" + "context" "crypto/md5" "encoding/binary" "encoding/hex" "errors" "fmt" - "go4.org/readerutil" "io" "net/http" "strings" @@ -34,6 +34,9 @@ type Template struct { // Prefix is the content at the beginning of the file before ZIP entries. // // It may be used to create self-extracting archives, for example. + // + // Prefix may implement ReaderAt interface from this package, in that case + // Prefix's ReadAtContext method will be called instead of ReadAt. Prefix io.ReaderAt // PrefixSize is size of Prefix in bytes. @@ -54,25 +57,11 @@ type Template struct { CreateTime time.Time } -type partsBuilder struct { - parts []readerutil.SizeReaderAt - offset int64 -} - -func (pb *partsBuilder) add(r readerutil.SizeReaderAt) { - size := r.Size() - if size == 0 { - return - } - pb.parts = append(pb.parts, r) - pb.offset += size -} - // Archive represents the ZIP file data to be downloaded by the user. // // It is a ReaderAt, so allows concurrent access to different byte ranges of the archive. 
type Archive struct { - data readerutil.SizeReaderAt + parts multiReaderAt createTime time.Time etag string } @@ -89,9 +78,9 @@ func NewArchive(t *Template) (*Archive, error) { return newArchive(t, bufferView, nil) } -type bufferViewFunc func(content func(w io.Writer) error) (readerutil.SizeReaderAt, error) +type bufferViewFunc func(content func(w io.Writer) error) (sizeReaderAt, error) -func bufferView(content func(w io.Writer) error) (readerutil.SizeReaderAt, error) { +func bufferView(content func(w io.Writer) error) (sizeReaderAt, error) { var buf bytes.Buffer err := content(&buf) @@ -101,17 +90,24 @@ func bufferView(content func(w io.Writer) error) (readerutil.SizeReaderAt, error return bytes.NewReader(buf.Bytes()), nil } +func readerAt(r io.ReaderAt) ReaderAt { + if v, ok := r.(ReaderAt); ok { + return v + } + return ignoreContext{r: r} +} + func newArchive(t *Template, view bufferViewFunc, testHookCloseSizeOffset func(size, offset uint64)) (*Archive, error) { if len(t.Comment) > uint16max { return nil, errors.New("comment too long") } + ar := new(Archive) dir := make([]*header, 0, len(t.Entries)) - var pb partsBuilder etagHash := md5.New() if t.Prefix != nil { - pb.add(&addsize{size: t.PrefixSize, source: t.Prefix}) + ar.parts.add(readerAt(t.Prefix), t.PrefixSize) var buf [8]byte binary.LittleEndian.PutUint64(buf[:], uint64(t.PrefixSize)) @@ -122,14 +118,14 @@ func newArchive(t *Template, view bufferViewFunc, testHookCloseSizeOffset func(s for _, entry := range t.Entries { prepareEntry(entry) - dir = append(dir, &header{FileHeader: entry, offset: uint64(pb.offset)}) + dir = append(dir, &header{FileHeader: entry, offset: uint64(ar.parts.size)}) header, err := view(func(w io.Writer) error { return writeHeader(w, entry) }) if err != nil { return nil, err } - pb.add(header) + ar.parts.addSizeReaderAt(header) io.Copy(etagHash, io.NewSectionReader(header, 0, header.Size())) if strings.HasSuffix(entry.Name, "/") { if entry.Content != nil { @@ -137,13 +133,13 @@ func 
newArchive(t *Template, view bufferViewFunc, testHookCloseSizeOffset func(s } } else { if entry.Content != nil { - pb.add(&addsize{size: int64(entry.CompressedSize64), source: entry.Content}) + ar.parts.add(readerAt(entry.Content), int64(entry.CompressedSize64)) } else if entry.CompressedSize64 != 0 { return nil, errors.New("empty entry with nonzero length") } // data descriptor dataDescriptor := makeDataDescriptor(entry) - pb.add(bytes.NewReader(dataDescriptor)) + ar.parts.addSizeReaderAt(bytes.NewReader(dataDescriptor)) etagHash.Write(dataDescriptor) } if entry.Modified.After(maxTime) { @@ -153,7 +149,7 @@ func newArchive(t *Template, view bufferViewFunc, testHookCloseSizeOffset func(s // capture central directory offset and comment so that content func for central directory // may be called multiple times and we don't store reference to t in the closure - centralDirectoryOffset := pb.offset + centralDirectoryOffset := ar.parts.size comment := t.Comment centralDirectory, err := view(func(w io.Writer) error { return writeCentralDirectory(centralDirectoryOffset, dir, w, comment, testHookCloseSizeOffset) @@ -161,29 +157,38 @@ func newArchive(t *Template, view bufferViewFunc, testHookCloseSizeOffset func(s if err != nil { return nil, err } - pb.add(centralDirectory) + ar.parts.addSizeReaderAt(centralDirectory) io.Copy(etagHash, io.NewSectionReader(centralDirectory, 0, centralDirectory.Size())) - createTime := t.CreateTime - if createTime.IsZero() { - createTime = maxTime + ar.createTime = t.CreateTime + if ar.createTime.IsZero() { + ar.createTime = maxTime } - etag := fmt.Sprintf("\"%s\"", hex.EncodeToString(etagHash.Sum(nil))) + ar.etag = fmt.Sprintf("\"%s\"", hex.EncodeToString(etagHash.Sum(nil))) - return &Archive{ - data: readerutil.NewMultiReaderAt(pb.parts...), - createTime: createTime, - etag: etag}, nil + return ar, nil } // Size returns the size of the archive in bytes. 
-func (ar *Archive) Size() int64 { return ar.data.Size() } +func (ar *Archive) Size() int64 { return ar.parts.Size() } // ReadAt provides the data of the file. // +// This is the same as calling ReadAtContext with context.TODO() +// // See io.ReaderAt for the interface. -func (ar *Archive) ReadAt(p []byte, off int64) (int, error) { return ar.data.ReadAt(p, off) } +func (ar *Archive) ReadAt(p []byte, off int64) (int, error) { return ar.parts.ReadAtContext(context.TODO(), p, off) } + +// ReadAtContext provides the data of the file. +// +// This method implements the ReaderAt interface. +// +// The context is passed to ReadAtContext of individual entries, if they implement it. The context is ignored if an +// entry implements just io.ReaderAt. +func (ar *Archive) ReadAtContext(ctx context.Context, p []byte, off int64) (int, error) { + return ar.parts.ReadAtContext(ctx, p, off) +} // ServeHTTP serves the archive over HTTP. // @@ -202,14 +207,6 @@ func (ar *Archive) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Etag", ar.etag) } - readseeker := io.NewSectionReader(ar.data, 0, ar.data.Size()) + readseeker := io.NewSectionReader(withContext{r: &ar.parts, ctx: r.Context()}, 0, ar.parts.Size()) http.ServeContent(w, r, "", ar.createTime, readseeker) } - -type addsize struct { - size int64 - source io.ReaderAt -} - -func (as *addsize) Size() int64 { return as.size } -func (as *addsize) ReadAt(p []byte, off int64) (int, error) { return as.source.ReadAt(p, off) } diff --git a/go.mod b/go.mod index f479181..b916792 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/martin-sucha/zipserve go 1.12 - -require go4.org v0.0.0-20180417224846-9599cf28b011 diff --git a/go.sum b/go.sum index 11d5405..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +0,0 @@ -go4.org v0.0.0-20180417224846-9599cf28b011 h1:i0QTVNl3j6yciHiQIHxz+mnsSQqo/xi78EGN7yNpMVw= -go4.org v0.0.0-20180417224846-9599cf28b011/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= diff --git 
a/io.go b/io.go new file mode 100644 index 0000000..67a1eea --- /dev/null +++ b/io.go @@ -0,0 +1,123 @@ +package zipserve + +import ( + "context" + "fmt" + "io" + "sort" +) + +// ReaderAt is like io.ReaderAt, but also takes context. +type ReaderAt interface { + // ReadAtContext has the same semantics as ReadAt from io.ReaderAt, but takes context. + ReadAtContext(ctx context.Context, p []byte, off int64) (n int, err error) +} + +type sizeReaderAt interface { + io.ReaderAt + Size() int64 +} + +type offsetAndData struct { + offset int64 + data ReaderAt +} + +// multiReaderAt is a ReaderAt that joins multiple ReaderAt sequentially together. +type multiReaderAt struct { + parts []offsetAndData + size int64 +} + +// add a part to the multiReaderAt. +// add can be used only before the reader is read from. +func (mcr *multiReaderAt) add(data ReaderAt, size int64) { + switch { + case size < 0: + panic(fmt.Sprintf("size cannot be negative: %v", size)) + case size == 0: + return + } + mcr.parts = append(mcr.parts, offsetAndData{ + offset: mcr.size, + data: data, + }) + mcr.size += size +} + +// addSizeReaderAt is like add, but takes sizeReaderAt +func (mcr *multiReaderAt) addSizeReaderAt(r sizeReaderAt) { + mcr.add(ignoreContext{r: r}, r.Size()) +} + +// endOffset is the offset where the given part ends. 
+func (mcr *multiReaderAt) endOffset(partIndex int) int64 { + if partIndex == len(mcr.parts)-1 { + return mcr.size + } + return mcr.parts[partIndex+1].offset +} + +func (mcr *multiReaderAt) ReadAtContext(ctx context.Context, p []byte, off int64) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + if off >= mcr.size { + return 0, io.EOF + } + // find first part that has data for p + firstPartIndex := sort.Search(len(mcr.parts), func(i int) bool { + return mcr.endOffset(i) > off + }) + for partIndex := firstPartIndex; partIndex < len(mcr.parts) && len(p) > 0; partIndex++ { + if partIndex > firstPartIndex { + off = mcr.parts[partIndex].offset + } + partRemainingBytes := mcr.endOffset(partIndex) - off + sizeToRead := int64(len(p)) + if sizeToRead > partRemainingBytes { + sizeToRead = partRemainingBytes + } + n2, err2 := mcr.parts[partIndex].data.ReadAtContext(ctx, p[0:sizeToRead], off - mcr.parts[partIndex].offset) + n += n2 + if err2 != nil { + return n, err2 + } + p = p[n2:] + } + if len(p) > 0 { + // tried reading beyond size + return n, io.EOF + } + return n, nil +} + +func (mcr *multiReaderAt) ReadAt(p []byte, off int64) (n int, err error) { + return mcr.ReadAtContext(context.TODO(), p, off) +} + +func (mcr *multiReaderAt) Size() int64 { + return mcr.size +} + +// ignoreContext converts io.ReaderAt to ReaderAt +type ignoreContext struct { + r io.ReaderAt +} + +func (a ignoreContext) ReadAtContext(_ context.Context, p []byte, off int64) (n int, err error) { + return a.r.ReadAt(p, off) +} + +// withContext converts ReaderAt to io.ReaderAt. +// +// While usually we shouldn't store context in a structure, we ensure that withContext lives only within single +// request. 
+type withContext struct { + ctx context.Context + r ReaderAt +} + +func (w withContext) ReadAt(p []byte, off int64) (n int, err error) { + return w.r.ReadAtContext(w.ctx, p, off) +} diff --git a/io_test.go b/io_test.go new file mode 100644 index 0000000..78ae1ea --- /dev/null +++ b/io_test.go @@ -0,0 +1,244 @@ +package zipserve + +import ( + "bytes" + "context" + "errors" + "io" + "testing" +) + +type testCheckContext struct { + r io.ReaderAt + f func(ctx context.Context) +} +func (a testCheckContext) ReadAtContext(ctx context.Context, p []byte, off int64) (n int, err error) { + a.f(ctx) + return a.r.ReadAt(p, off) +} + +func TestMultiReaderAt_ReadAtContext(t *testing.T) { + tests := []struct { + name string + parts []string + offset int64 + size int64 + expectedResult string + expectedError string + }{ + { + name: "empty", + parts: nil, + offset: 0, + size: 0, + expectedResult: "", + }, + { + name: "empty size out of bounds", + parts: nil, + offset: 0, + size: 1, + expectedResult: "", + expectedError: "EOF", + }, + { + name: "empty offset out of bounds", + parts: nil, + offset: 1, + size: 1, + expectedResult: "", + expectedError: "EOF", + }, + { + name: "single part full", + parts: []string{"abcdefgh"}, + offset: 0, + size: 8, + expectedResult: "abcdefgh", + }, + { + name: "single part start", + parts: []string{"abcdefgh"}, + offset: 0, + size: 3, + expectedResult: "abc", + }, + { + name: "single part middle", + parts: []string{"abcdefgh"}, + offset: 3, + size: 3, + expectedResult: "def", + }, + { + name: "single part end", + parts: []string{"abcdefgh"}, + offset: 4, + size: 4, + expectedResult: "efgh", + }, + { + name: "single part size out of bounds", + parts: []string{"abcdefgh"}, + offset: 4, + size: 10, + expectedResult: "efgh", + expectedError: "EOF", + }, + { + name: "single part offset out of bounds", + parts: []string{"abcdefgh"}, + offset: 4, + size: 10, + expectedResult: "efgh", + expectedError: "EOF", + }, + { + name: "single part empty", + parts: 
[]string{"abcdefgh"}, + offset: 0, + size: 0, + expectedResult: "", + }, + { + name: "multiple parts full", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 0, + size: 19, + expectedResult: "abcdefghijklmnopqrs", + }, + { + name: "multiple parts beginning", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 0, + size: 4, + expectedResult: "abcd", + }, + { + name: "multiple parts beginning 2", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 0, + size: 10, + expectedResult: "abcdefghij", + }, + { + name: "multiple parts middle 1", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 9, + size: 3, + expectedResult: "jkl", + }, + { + name: "multiple parts middle 2", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 6, + size: 4, + expectedResult: "ghij", + }, + { + name: "multiple parts middle 3", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 6, + size: 10, + expectedResult: "ghijklmnop", + }, + { + name: "multiple parts end", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 6, + size: 13, + expectedResult: "ghijklmnopqrs", + }, + { + name: "multiple parts end 2", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 15, + size: 4, + expectedResult: "pqrs", + }, + { + name: "multiple parts size out of bounds", + parts: []string{"abcdefgh", "ijklm", "nopqrs"}, + offset: 6, + size: 30, + expectedResult: "ghijklmnopqrs", + expectedError: "EOF", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + type testContextKey struct{} + ctx := context.WithValue(context.Background(), testContextKey{}, test.name) + + var mcr multiReaderAt + for i := range test.parts { + reader := testCheckContext{ + r: bytes.NewReader([]byte(test.parts[i])), + f: func(ctx context.Context) { + v := ctx.Value(testContextKey{}) + if v != test.name { + t.Logf("expected context value to be propagated, got %v", v) + t.Fail() + } + }, + } + mcr.add(reader, int64(len(test.parts[i]))) + } + p := 
make([]byte, test.size) + n, err := mcr.ReadAtContext(ctx, p, test.offset) + if n < 0 || n > len(p) { + t.Log("n out of bounds") + t.Fail() + } else { + result := string(p[:n]) + if test.expectedResult != result { + t.Logf("expected read %q, but got %q", test.expectedResult, result) + t.Fail() + } + if n < len(p) && err == nil { + t.Log("short read without error") + t.Fail() + } + } + if test.expectedError == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + switch { + case err == nil: + t.Fatalf("expected error %q, but got nil", test.expectedError) + case err.Error() != test.expectedError: + t.Fatalf("expected error %q, but got %q", test.expectedError, err.Error()) + } + } + }) + } +} + +type readWithError struct { + data []byte + err error +} + +func (r readWithError) ReadAtContext(ctx context.Context, p []byte, off int64) (n int, err error) { + return copy(p, r.data), r.err +} + +func TestMultiReaderAt_ReadAtContextError(t *testing.T) { + myError := errors.New("my error") + var mcr multiReaderAt + mcr.add(ignoreContext{r: bytes.NewReader([]byte("abc"))}, 3) + mcr.add(readWithError{data: []byte("def"), err: myError}, 10) + mcr.add(ignoreContext{r: bytes.NewReader([]byte("opqrst"))}, 6) + p := make([]byte, 10) + n, err := mcr.ReadAtContext(context.Background(), p, 1) + if n != 5 { + t.Logf("expected n=5, got %v", n) + t.Fail() + } + if !errors.Is(err, myError) { + t.Logf("expected err=%v, got %v", myError, err) + t.Fail() + } +} diff --git a/struct.go b/struct.go index 7c68ad3..ada9d42 100644 --- a/struct.go +++ b/struct.go @@ -115,6 +115,9 @@ type FileHeader struct { // the content must be compressed using the Method specified. // In case Store is used (the default), the compressed data is the same as // uncompressed data. + // + // Content may implement ReaderAt interface from this package, in that case + // Content's ReadAtContext method will be called instead of ReadAt. 
Content io.ReaderAt } diff --git a/zip_test.go b/zip_test.go index c22a19c..b2fadbf 100644 --- a/zip_test.go +++ b/zip_test.go @@ -12,7 +12,6 @@ import ( "encoding/binary" "errors" "fmt" - "go4.org/readerutil" "hash/crc32" "io" "io/ioutil" @@ -390,7 +389,7 @@ func suffixIsZip64(t *testing.T, zip sizedReaderAt) bool { return true } -func rleView(content func(w io.Writer) error) (readerutil.SizeReaderAt, error) { +func rleView(content func(w io.Writer) error) (sizeReaderAt, error) { buf := new(rleBuffer) err := content(buf) if err != nil { @@ -459,9 +458,10 @@ func TestZip64LargeDirectory(t *testing.T) { // sizeWithEnd returns a sized ReaderAt of size dots plus "END\n" appended func sizeWithEnd(size int64) sizedReaderAt { - return readerutil.NewMultiReaderAt( - io.NewSectionReader(&sameBytes{b: '.'}, 0, size), - bytes.NewReader([]byte("END\n"))) + var mcr multiReaderAt + mcr.addSizeReaderAt(io.NewSectionReader(&sameBytes{b: '.'}, 0, size)) + mcr.addSizeReaderAt(bytes.NewReader([]byte("END\n"))) + return &mcr } func testZip64(t testing.TB, size int64) sizedReaderAt {