diff --git a/fastcache.go b/fastcache.go index 8a3268e..2f43bc1 100644 --- a/fastcache.go +++ b/fastcache.go @@ -68,7 +68,13 @@ type Options struct { // Cache based on uri+querystring. IncludeQueryString bool + // Compression options. Compression CompressionsOptions + + // QueryArgsTransformerHook is a hook that can be used to transform the query string + // before it is used to generate the cache key. This can be used to selectively + // include / exclude query parameters from the cache key. + QueryArgsTransformerHook func(*fasthttp.Args) } // Item represents the cache entry for a single endpoint with the actual cache @@ -126,17 +132,10 @@ func (f *FastCache) Cached(h fastglue.FastRequestHandler, o *Options, group stri o.Compression.MinLength = 500 } - var hash [16]byte - // If IncludeQueryString option is set then cache based on uri + md5(query_string) - if o.IncludeQueryString { - hash = md5.Sum(r.RequestCtx.URI().FullURI()) - } else { - hash = md5.Sum(r.RequestCtx.URI().Path()) - } - uri := hex.EncodeToString(hash[:]) + hashedURI := hash(f.makeURI(r, o)) // Fetch etag + cached bytes from the store. - blob, err := f.s.Get(namespace, group, uri) + blob, err := f.s.Get(namespace, group, hashedURI) if err != nil { o.Logger.Printf("error reading cache: %v", err) } @@ -237,7 +236,7 @@ func (f *FastCache) ClearGroup(h fastglue.FastRequestHandler, o *Options, groups // Del deletes the cache for a single URI in a namespace->group. func (f *FastCache) Del(namespace, group, uri string) error { - return f.s.Del(namespace, group, uri) + return f.s.Del(namespace, group, hash(uri)) } // DelGroup deletes all cached URIs under a group. @@ -245,6 +244,48 @@ func (f *FastCache) DelGroup(namespace string, group ...string) error { return f.s.DelGroup(namespace, group...) } +// hash returns the md5 hash of a string. +func hash(b string) string { + var hash [16]byte = md5.Sum([]byte(b)) + + return hex.EncodeToString(hash[:]) +} + +// makeURI returns the URI to be used as the cache key. +func (f *FastCache) makeURI(r *fastglue.Request, o *Options) string { + // lexicographically sort the query string. + r.RequestCtx.QueryArgs().Sort(func(x, y []byte) int { + return bytes.Compare(x, y) + }) + + // If IncludeQueryString option is set then cache based on uri + md5(query_string) + if o.IncludeQueryString { + id := r.RequestCtx.URI().FullURI() + + // Check if we need to include only specific query params. + if o.QueryArgsTransformerHook != nil { + // Acquire a copy so as to not modify the request. + uriRaw := fasthttp.AcquireURI() + r.RequestCtx.URI().CopyTo(uriRaw) + + q := uriRaw.QueryArgs() + + // Call the hook to transform the query string. + o.QueryArgsTransformerHook(q) + + // Get the new URI. + id = uriRaw.FullURI() + + // Release the borrowed URI. + fasthttp.ReleaseURI(uriRaw) + } + + return string(id) + } + + return string(r.RequestCtx.URI().Path()) +} + // cache caches a response body. func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Options) error { // ETag?. @@ -258,14 +299,7 @@ func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Optio } // Write cache to the store (etag, content type, response body). - var hash [16]byte - // If IncludeQueryString option is set then cache based on uri + md5(query_string) - if o.IncludeQueryString { - hash = md5.Sum(r.RequestCtx.URI().FullURI()) - } else { - hash = md5.Sum(r.RequestCtx.URI().Path()) - } - uri := hex.EncodeToString(hash[:]) + hashedURI := hash(f.makeURI(r, o)) var blob []byte if !o.NoBlob { @@ -289,7 +323,7 @@ func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Optio } } - err := f.s.Put(namespace, group, uri, item, o.TTL) + err := f.s.Put(namespace, group, hashedURI, item, o.TTL) if err != nil { return fmt.Errorf("error writing cache to store: %v", err) } diff --git a/fastcache_test.go b/fastcache_test.go index d2d2248..8d166d2 100644 --- a/fastcache_test.go +++ b/fastcache_test.go @@ -3,6 +3,7 @@ package fastcache_test import ( "bytes" "compress/gzip" + "fmt" "io" "log" "net/http" @@ -28,6 +29,8 @@ const ( var ( srv = fastglue.NewGlue() + fc *fastcache.FastCache + content = []byte("this is the reasonbly long test content that may be compressed") ) @@ -70,11 +73,54 @@ func init() { Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), } - fc = fastcache.New(cachestore.New("CACHE:", redis.NewClient(&redis.Options{ - Addr: rd.Addr(), - }))) + includeQS = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: true, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + } + + includeQSNoEtag = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: false, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + } + + includeQSSpecific = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: true, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + QueryArgsTransformerHook: func(args *fasthttp.Args) { + // Copy the keys to delete, and delete them later. This is to + // avoid borking the VisitAll() iterator. + mp := map[string]struct{}{ + "foo": {}, + } + + delKeys := [][]byte{} + args.VisitAll(func(k, v []byte) { + if _, ok := mp[string(k)]; !ok { + delKeys = append(delKeys, k) + } + }) + + // Delete the keys. + for _, k := range delKeys { + args.DelBytes(k) + } + }, + } ) + fc = fastcache.New(cachestore.New("CACHE:", redis.NewClient(&redis.Options{ + Addr: rd.Addr(), + }))) + // Handlers. srv.Before(func(r *fastglue.Request) *fastglue.Request { r.RequestCtx.SetUserValue(namespaceKey, "test") @@ -102,6 +148,19 @@ func init() { return r.SendBytes(200, "text/plain", content) }, cfgDefault, group)) + srv.GET("/include-qs", fc.Cached(func(r *fastglue.Request) error { + return r.SendBytes(200, "text/plain", content) + }, includeQS, group)) + + srv.GET("/include-qs-no-etag", fc.Cached(func(r *fastglue.Request) error { + out := time.Now() + return r.SendBytes(200, "text/plain", []byte(fmt.Sprintf("%v", out))) + }, includeQSNoEtag, group)) + + srv.GET("/include-qs-specific", fc.Cached(func(r *fastglue.Request) error { + return r.SendBytes(200, "text/plain", content) + }, includeQSSpecific, group)) + // Start the server go func() { s := &fasthttp.Server{ @@ -177,6 +236,7 @@ func TestCache(t *testing.T) { if r.StatusCode != 200 { t.Fatalf("expected 200 but got %v", r.StatusCode) } + r, b = getReq(srvRoot+"/cached", r.Header.Get("Etag"), false, t) if r.StatusCode != 200 { t.Fatalf("expected 200 but got '%v'", r.StatusCode) @@ -213,6 +273,137 @@ func TestCache(t *testing.T) { } } +func TestCacheDelete(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/cached", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, b = getReq(srvRoot+"/cached", r.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } + if !bytes.Equal(b, []byte("")) { + t.Fatalf("expected empty cached body but got '%v'", b) + } + + err := fc.Del(namespaceKey, group, "/cached") + if err != nil { + t.Fatalf("expected nil but got %v", err) + } + + // New request should be 200, since the cache is deleted. + r, b = getReq(srvRoot+"/cached", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } +} + +func TestQueryString(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs?foo=bar", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ = getReq(srvRoot+"/include-qs?foo=bar", r.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } +} + +func TestQueryStringLexicographical(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs?foo=bar&baz=qux", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ = getReq(srvRoot+"/include-qs?baz=qux&foo=bar", r.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } +} + +func TestQueryStringWithoutEtag(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + // Second should be 200 but with same response. + r2, b2 := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", false, t) + if r2.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r2.StatusCode) + } + + if !bytes.Equal(b, b2) { + t.Fatalf("expected '%v' in body but got %v", b, b2) + } + + // Third should be 200 but with different response. + r3, b3 := getReq(srvRoot+"/include-qs-no-etag?foo=baz", "", false, t) + if r3.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r3.StatusCode) + } + + // time should be different + if bytes.Equal(b, b3) { + t.Fatalf("expected both to be different (should not be %v), but got %v", b, b3) + } +} + +func TestQueryStringSpecific(t *testing.T) { + // First request should be 200. + r1, b := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", "", false, t) + if r1.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r1.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } + + // Third should be 304 as foo=bar + r, _ = getReq(srvRoot+"/include-qs-specific?loo=mar&foo=bar&baz=qux&quux=quuz", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } + + // Fourth should be 200 as foo=rab + r, b = getReq(srvRoot+"/include-qs-specific?foo=rab&baz=qux&quux=quuz", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } +} + func TestNoCache(t *testing.T) { // All requests should return 200. for n := 0; n < 3; n++ { diff --git a/stores/goredis/redis.go b/stores/goredis/redis.go index 5da0da2..baff0cc 100644 --- a/stores/goredis/redis.go +++ b/stores/goredis/redis.go @@ -57,6 +57,7 @@ func (s *Store) Get(namespace, group, uri string) (fastcache.Item, error) { var ( out fastcache.Item ) + // Get content_type, etag, blob in that order. cmd := s.cn.HMGet(s.ctx, s.key(namespace, group), s.field(keyCtype, uri), s.field(keyEtag, uri), s.field(keyCompression, uri), s.field(keyBlob, uri)) if err := cmd.Err(); err != nil {