diff --git a/client.go b/client.go index 2b30c44..64ce440 100644 --- a/client.go +++ b/client.go @@ -110,7 +110,7 @@ func (c *Client) Start(octx context.Context) (Controller, error) { traceSink = proxy } - incoming := make(chan ctrlRequest, c.numWorkers) + incoming := make(chan any, c.numWorkers) outgoing := make(chan Resource, c.numWorkers) syncoutgoing := make(chan synchronousRequest, c.numWorkers) wg.Add(c.numWorkers) diff --git a/controller.go b/controller.go index 4cc1965..f8ff6ae 100644 --- a/controller.go +++ b/controller.go @@ -10,6 +10,7 @@ import ( type Controller interface { AddResource(Resource) error + Lookup(string) (Resource, error) RemoveResource(string) error Refresh(string) error ShutdownContext(context.Context) error @@ -20,7 +21,7 @@ type controller struct { cancel context.CancelFunc check *time.Ticker // incoming accepts new control requests from external sources - incoming chan ctrlRequest + incoming chan any // outgoing sends Syncer objects to the worker pool outgoing chan Resource @@ -60,18 +61,39 @@ func (c *controller) ShutdownContext(ctx context.Context) error { } } -const ( - addResource = iota - rmResource - refreshResource -) - -type ctrlRequest struct { - op int - reply chan error +type ctrlRequest[T any] struct { + reply chan T resource Resource u string } +type lookupReply struct { + r Resource + err error +} + +type addRequest ctrlRequest[error] +type rmRequest ctrlRequest[error] +type refreshRequest ctrlRequest[error] +type lookupRequest ctrlRequest[lookupReply] + +// Lookup returns a resource by its URL. If the resource does not exist, it +// will return an error. +// +// Unfortunately, due to the way typed parameters are handled in Go, we can only +// return a Resource object (and not a ResourceBase[T] object). This means that +// you will either need to use the `Resource.Get()` method or use a type +// assertion to obtain a `ResourceBase[T]` to get to the actual object you are +// looking for +func (c *controller) Lookup(u string) (Resource, error) { + // to avoid having to acquire locks, we do this asynchronously + reply := make(chan lookupReply, 1) + c.incoming <- lookupRequest{ + reply: reply, + u: u, + } + r := <-reply + return r.r, r.err +} // AddResource adds a new resource to the controller. If the resource already // exists, it will return an error. @@ -81,8 +103,7 @@ func (c *controller) AddResource(r Resource) error { } reply := make(chan error, 1) - c.incoming <- ctrlRequest{ - op: addResource, + c.incoming <- addRequest{ reply: reply, resource: r, } @@ -93,8 +114,7 @@ func (c *controller) AddResource(r Resource) error { // not exist, it will return an error. func (c *controller) RemoveResource(u string) error { reply := make(chan error, 1) - c.incoming <- ctrlRequest{ - op: rmResource, + c.incoming <- rmRequest{ reply: reply, u: u, } @@ -107,17 +127,16 @@ func (c *controller) RemoveResource(u string) error { // This function is synchronous, and will block until the resource has been refreshed. func (c *controller) Refresh(u string) error { reply := make(chan error, 1) - c.incoming <- ctrlRequest{ - op: refreshResource, + c.incoming <- refreshRequest{ reply: reply, u: u, } return <-reply } -func (c *controller) handleRequest(ctx context.Context, req ctrlRequest) { - switch req.op { - case addResource: +func (c *controller) handleRequest(ctx context.Context, req any) { + switch req := req.(type) { + case addRequest: r := req.resource for _, item := range c.items { if item.URL() == r.URL() { @@ -128,7 +147,7 @@ func (c *controller) handleRequest(ctx context.Context, req ctrlRequest) { } c.items = append(c.items, r) - sendReply(ctx, req.reply, nil) + closeReply(req.reply) // force the next check to happen immediately if d := r.ConstantInterval(); d > 0 { @@ -138,7 +157,7 @@ func (c *controller) handleRequest(ctx context.Context, req ctrlRequest) { } c.check.Reset(time.Nanosecond) - case rmResource: + case rmRequest: u := req.u minInterval := oneDay loc := -1 @@ -158,9 +177,9 @@ func (c *controller) handleRequest(ctx context.Context, req ctrlRequest) { } c.items = slices.Delete(c.items, loc, loc+1) - sendReply(ctx, req.reply, nil) + closeReply[error](req.reply) c.check.Reset(minInterval) - case refreshResource: + case refreshRequest: u := req.u for _, item := range c.items { if item.URL() != u { @@ -174,6 +193,15 @@ func (c *controller) handleRequest(ctx context.Context, req ctrlRequest) { return } sendReply(ctx, req.reply, errResourceNotFound) + case lookupRequest: + u := req.u + for _, item := range c.items { + if item.URL() == u { + sendReply(ctx, req.reply, lookupReply{r: item}) + return + } + } + sendReply(ctx, req.reply, lookupReply{err: errResourceNotFound}) } } @@ -193,15 +221,15 @@ func sendWorkerSynchronous(ctx context.Context, ch chan synchronousRequest, r sy } } -func sendReply(ctx context.Context, ch chan error, err error) { - defer close(ch) - if err == nil { - return - } +func closeReply[T any](ch chan T) { + close(ch) +} +func sendReply[T any](ctx context.Context, ch chan T, v T) { + defer closeReply[T](ch) select { case <-ctx.Done(): - case ch <- err: + case ch <- v: } } diff --git a/httprc_test.go b/httprc_test.go index 2822d97..d295807 100644 --- a/httprc_test.go +++ b/httprc_test.go @@ -116,6 +116,14 @@ func TestClient(t *testing.T) { require.Equal(t, tc.Expected, dst, `r.Get should return expected value`) }) } + + for _, tc := range testcases { + t.Run("Lookup "+tc.URL, func(t *testing.T) { + r, err := ctrl.Lookup(tc.URL) + require.NoError(t, err, `ctrl.Lookup should succeed`) + require.Equal(t, tc.URL, r.URL(), `r.URL should return expected value`) + }) + } } func TestRefresh(t *testing.T) {