diff --git a/packages/rate-limiter/.eslintrc.json b/packages/core/.eslintrc.json similarity index 95% rename from packages/rate-limiter/.eslintrc.json rename to packages/core/.eslintrc.json index a6175d5..e281c9c 100644 --- a/packages/rate-limiter/.eslintrc.json +++ b/packages/core/.eslintrc.json @@ -1,30 +1,30 @@ -{ - "extends": ["../../.eslintrc.base.json"], - "ignorePatterns": ["!**/*"], - "overrides": [ - { - "files": ["*.ts", "*.tsx", "*.js", "*.jsx"], - "rules": {} - }, - { - "files": ["*.ts", "*.tsx"], - "rules": {} - }, - { - "files": ["*.js", "*.jsx"], - "rules": {} - }, - { - "files": ["*.json"], - "parser": "jsonc-eslint-parser", - "rules": { - "@nx/dependency-checks": [ - "error", - { - "ignoredFiles": ["{projectRoot}/vite.config.{js,ts,mjs,mts}"] - } - ] - } - } - ] -} +{ + "extends": ["../../.eslintrc.base.json"], + "ignorePatterns": ["!**/*"], + "overrides": [ + { + "files": ["*.ts", "*.tsx", "*.js", "*.jsx"], + "rules": {} + }, + { + "files": ["*.ts", "*.tsx"], + "rules": {} + }, + { + "files": ["*.js", "*.jsx"], + "rules": {} + }, + { + "files": ["*.json"], + "parser": "jsonc-eslint-parser", + "rules": { + "@nx/dependency-checks": [ + "error", + { + "ignoredFiles": ["{projectRoot}/vite.config.{js,ts,mjs,mts}"] + } + ] + } + } + ] +} diff --git a/packages/rate-limiter/.swcrc b/packages/core/.swcrc similarity index 100% rename from packages/rate-limiter/.swcrc rename to packages/core/.swcrc diff --git a/packages/rate-limiter/package.json b/packages/core/package.json similarity index 94% rename from packages/rate-limiter/package.json rename to packages/core/package.json index af1f0cf..2bb6830 100644 --- a/packages/rate-limiter/package.json +++ b/packages/core/package.json @@ -1,13 +1,13 @@ -{ - "name": "hono-rate-limiter", - "version": "0.0.1", - "dependencies": { - "@swc/helpers": "~0.5.2" - }, - "peerDependencies": { - "hono": "^4.1.1" - }, - "type": "commonjs", - "main": "./src/index.js", - "typings": "./src/index.d.ts" -} +{ + "name": "hono-rate-limiter", + "version": "0.0.1", + "dependencies": { + "@swc/helpers": "~0.5.2" + }, + "peerDependencies": { + "hono": "^4.1.1" + }, + "type": "commonjs", + "main": "./src/index.js", + "typings": "./src/index.d.ts" +} diff --git a/packages/rate-limiter/project.json b/packages/core/project.json similarity index 72% rename from packages/rate-limiter/project.json rename to packages/core/project.json index af9eb1b..de0bdfa 100644 --- a/packages/rate-limiter/project.json +++ b/packages/core/project.json @@ -1,53 +1,53 @@ -{ - "name": "rate-limiter", - "$schema": "../../node_modules/nx/schemas/project-schema.json", - "sourceRoot": "packages/rate-limiter/src", - "projectType": "library", - "targets": { - "build": { - "executor": "@nx/js:swc", - "outputs": ["{options.outputPath}"], - "options": { - "outputPath": "dist/packages/rate-limiter", - "main": "packages/rate-limiter/src/index.ts", - "tsConfig": "packages/rate-limiter/tsconfig.lib.json", - "assets": ["README.md"] - } - }, - "nx-release-publish": { - "options": { - "packageRoot": "dist\\{projectRoot}" - } - }, - "test": { - "executor": "@nx/vite:test", - "outputs": ["{options.reportsDirectory}"], - "options": { - "reportsDirectory": "../../coverage/packages/rate-limiter" - } - }, - "version": { - "executor": "@jscutlery/semver:version", - "options": { - "preset": "conventional" - } - }, - "deploy": { - "executor": "ngx-deploy-npm:deploy", - "options": { - "access": "public", - "distFolderPath": "dist/packages/rate-limiter" - }, - "dependsOn": ["build"] - } - }, - "tags": [], - "release": { - "version": { - "generatorOptions": { - "packageRoot": "dist\\{projectRoot}", - "currentVersionResolver": "git-tag" - } - } - } -} +{ + "name": "core", + "$schema": "../../node_modules/nx/schemas/project-schema.json", + "sourceRoot": "packages/core/src", + "projectType": "library", + "targets": { + "build": { + "executor": "@nx/js:swc", + "outputs": ["{options.outputPath}"], + "options": { + "outputPath": "dist/packages/core", + "main": "packages/core/src/index.ts", + "tsConfig": "packages/core/tsconfig.lib.json", + "assets": ["README.md"] + } + }, + "nx-release-publish": { + "options": { + "packageRoot": "dist\\{projectRoot}" + } + }, + "test": { + "executor": "@nx/vite:test", + "outputs": ["{options.reportsDirectory}"], + "options": { + "reportsDirectory": "../../coverage/packages/core" + } + }, + "version": { + "executor": "@jscutlery/semver:version", + "options": { + "preset": "conventional" + } + }, + "deploy": { + "executor": "ngx-deploy-npm:deploy", + "options": { + "access": "public", + "distFolderPath": "dist/packages/core" + }, + "dependsOn": ["build"] + } + }, + "tags": [], + "release": { + "version": { + "generatorOptions": { + "packageRoot": "dist\\{projectRoot}", + "currentVersionResolver": "git-tag" + } + } + } +} diff --git a/packages/rate-limiter/src/core/__tests__/headers.spec.ts b/packages/core/src/core/__tests__/headers.spec.ts similarity index 100% rename from packages/rate-limiter/src/core/__tests__/headers.spec.ts rename to packages/core/src/core/__tests__/headers.spec.ts diff --git a/packages/rate-limiter/src/core/__tests__/helpers/create-server.ts b/packages/core/src/core/__tests__/helpers/create-server.ts similarity index 87% rename from packages/rate-limiter/src/core/__tests__/helpers/create-server.ts rename to packages/core/src/core/__tests__/helpers/create-server.ts index 7c41832..fc4b9b8 100644 --- a/packages/rate-limiter/src/core/__tests__/helpers/create-server.ts +++ b/packages/core/src/core/__tests__/helpers/create-server.ts @@ -9,7 +9,9 @@ export function createServer< I extends Input = NonNullable, >({ middleware, -}: { middleware: MiddlewareHandler | MiddlewareHandler[] }) { +}: { + middleware: MiddlewareHandler | MiddlewareHandler[]; +}) { const wares = Array.isArray(middleware) ? middleware : [middleware]; // Init the app diff --git a/packages/rate-limiter/src/core/__tests__/helpers/index.ts b/packages/core/src/core/__tests__/helpers/index.ts similarity index 100% rename from packages/rate-limiter/src/core/__tests__/helpers/index.ts rename to packages/core/src/core/__tests__/helpers/index.ts diff --git a/packages/rate-limiter/src/core/__tests__/middleware.spec.ts b/packages/core/src/core/__tests__/middleware.spec.ts similarity index 96% rename from packages/rate-limiter/src/core/__tests__/middleware.spec.ts rename to packages/core/src/core/__tests__/middleware.spec.ts index 1a9fbfc..0ccb993 100644 --- a/packages/rate-limiter/src/core/__tests__/middleware.spec.ts +++ b/packages/core/src/core/__tests__/middleware.spec.ts @@ -1,849 +1,849 @@ -// import { platform } from 'node:process' -import { createAdaptorServer } from "@hono/node-server"; -import { agent as request } from "supertest"; -import { rateLimiter } from ".."; -import type { - ClientRateLimitInfo, - ConfigType, - RateLimitInfo, - Store, -} from "../../types"; -import { createServer } from "./helpers"; - -describe.skip("middleware test", () => { - beforeEach(() => { - vi.useFakeTimers(); - }); - afterEach(() => { - vi.useRealTimers(); - vi.restoreAllMocks(); - }); - - class MockStore implements Store { - initWasCalled = false; - incrementWasCalled = false; - decrementWasCalled = false; - resetKeyWasCalled = false; - getWasCalled = false; - resetAllWasCalled = false; - - counter = 0; - - init(_options: ConfigType): void { - this.initWasCalled = true; - } - - async get(_key: string): Promise { - this.getWasCalled = true; - - return { totalHits: this.counter, resetTime: undefined }; - } - - async increment(_key: string): Promise { - this.counter += 1; - this.incrementWasCalled = true; - - return { totalHits: this.counter, resetTime: undefined }; - } - - async decrement(_key: string): Promise { - this.counter -= 1; - this.decrementWasCalled = true; - } - - async resetKey(_key: string): Promise { - this.resetKeyWasCalled = true; - } - - async resetAll(): Promise { - this.resetAllWasCalled = true; - } - } - - it("should not modify the options object passed", () => { - const options = {}; - rateLimiter(options); - expect(options).toStrictEqual({}); - }); - - it("should call `init` even if no requests have come in", async () => { - const store = new MockStore(); - rateLimiter({ - store, - }); - - expect(store.initWasCalled).toEqual(true); - }); - - it("should let the first request through", async () => { - const app = createAdaptorServer( - createServer({ middleware: rateLimiter({ limit: 1 }) }), - ); - - await request(app).get("/").expect(200).expect("Hi there!"); - }); - - it("should refuse additional connections once IP has reached the max", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - }); - - it("should (eventually) accept new connections from a blocked IP", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - windowMs: 50, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - vi.advanceTimersByTime(60); - await request(app).get("/").expect(200); - }); - - it("should work repeatedly", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - windowMs: 50, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - vi.advanceTimersByTime(60); - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - vi.advanceTimersByTime(60); - await request(app).get("/").expect(200); - }); - - it("should block all requests if max is set to 0", async () => { - const app = createAdaptorServer( - createServer({ middleware: rateLimiter({ limit: 0 }) }), - ); - - await request(app).get("/").expect(429); - }); - - it("should show the provided message instead of the default message when max connections are reached", async () => { - const message = "Enhance your calm"; - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - windowMs: 1000, - limit: 2, - message, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429).expect(message); - }); - - it("should allow the error status code to be customized", async () => { - const statusCode = 420; - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 1, - // @ts-expect-error - statusCode, - }), - }), - ); - await request(app).get("/").expect(200); - await request(app).get("/").expect(statusCode); - }); - - it.skip("should allow responding with a JSON message", async () => { - const message = { - error: { - code: "too-many-requests", - message: "Too many requests were attempted in a short span of time.", - }, - }; - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - message, - limit: 1, - }), - }), - ); - - await request(app).get("/").expect(200, "Hi there!"); - await request(app).get("/").expect(429, message); - }); - - it("should allow message to be a function", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - message: () => "Too many requests.", - limit: 1, - }), - }), - ); - - await request(app).get("/").expect(200, "Hi there!"); - await request(app).get("/").expect(429, "Too many requests."); - }); - - it("should allow message to be a function that returns a promise", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - message: async () => "Too many requests.", - limit: 1, - }), - }), - ); - - await request(app).get("/").expect(200, "Hi there!"); - await request(app).get("/").expect(429, "Too many requests."); - }); - - it("should use a custom handler when specified", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 1, - handler(c) { - // @ts-expect-error - c.status(420); - return c.text("Enhance your calm"); - }, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(420, "Enhance your calm"); - }); - - it("should allow custom key generators", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - keyGenerator: (c) => c.req.query("key") ?? "", - }), - }), - ); - - await request(app).get("/").query({ key: 1 }).expect(200); - await request(app).get("/").query({ key: 1 }).expect(200); - - await request(app).get("/").query({ key: 2 }).expect(200); - - await request(app).get("/").query({ key: 1 }).expect(429); - - await request(app).get("/").query({ key: 2 }).expect(200); - await request(app).get("/").query({ key: 2 }).expect(429); - }); - - it("should allow custom skip function", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - skip: () => true, - }), - }), - ); - - await request(app).get("/").query({ key: 1 }).expect(200); - await request(app).get("/").query({ key: 1 }).expect(200); - - await request(app).get("/").query({ key: 1 }).expect(200); - }); - - it("should allow custom skip function that returns a promise", async () => { - const limiter = rateLimiter({ - limit: 2, - skip: async () => true, - }); - - const app = createAdaptorServer(createServer({ middleware: limiter })); - await request(app).get("/").query({ key: 1 }).expect(200); - await request(app).get("/").query({ key: 1 }).expect(200); - - await request(app).get("/").query({ key: 1 }).expect(200); - }); - - it("should allow max to be a function", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: () => 2, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - }); - - it("should allow max to be a function that returns a promise", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: async () => 2, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - }); - - it("should calculate the remaining hits", async () => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: async () => 2, - }), - }), - ); - - await request(app) - .get("/") - .expect(200) - .expect("x-ratelimit-limit", "2") - .expect("x-ratelimit-remaining", "1") - .expect((response) => { - if ("retry-after" in response.headers) { - throw new Error( - `Expected no retry-after header, got ${ - response.headers["retry-after"] as string - }`, - ); - } - }) - .expect(200, "Hi there!"); - }); - - it.each([["modern", new MockStore()]])( - "should call `increment` on the store (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - store, - }), - }), - ); - await request(app).get("/"); - - expect(store.incrementWasCalled).toEqual(true); - }, - ); - - it.skip.each([["modern", new MockStore()]])( - "should call `resetKey` on the store (%s store)", - async (name, store) => { - const limiter = rateLimiter({ - store, - }); - - limiter.resetKey("key"); - - expect(store.resetKeyWasCalled).toEqual(true); - }, - ); - - it.skip.each([["modern", new MockStore()]])( - "should call `get` on the store (%s store)", - async (name, store) => { - const limiter = rateLimiter({ - store, - }); - - const response = await limiter.getKey("key"); - - expect(store.getWasCalled).toEqual(true); - expect(typeof response?.totalHits).toBe("number"); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when requests succeed and `skipSuccessfulRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - store, - }), - }), - ); - - await request(app).get("/").expect(200); - - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should not decrement hits when requests fail and `skipSuccessfulRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - store, - }), - }), - ); - - await request(app).get("/error").expect(400); - - expect(store.decrementWasCalled).toEqual(false); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when requests succeed, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - requestWasSuccessful: (c) => c.res.status === 200, - store, - }), - }), - ); - - await request(app).get("/").expect(200); - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should not decrement hits when requests fail, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - requestWasSuccessful(c) { - return c.res.status === 200; - }, - store, - }), - }), - ); - - await request(app).get("/error").expect(400); - - expect(store.decrementWasCalled).toEqual(false); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when requests succeed, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - requestWasSuccessful: (c) => c.req.query("success") === "1", - store, - }), - }), - ); - - await request(app).get("/?success=1"); - - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should not decrement hits when requests fail, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipSuccessfulRequests: true, - requestWasSuccessful: (c) => c.req.query("success") === "1", - store, - }), - }), - ); - - await request(app).get("/?success=0"); - - expect(store.decrementWasCalled).toEqual(false); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when requests fail and `skipFailedRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipFailedRequests: true, - store, - }), - }), - ); - - await request(app).get("/error").expect(400); - - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should not decrement hits when requests succeed and `skipFailedRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipFailedRequests: true, - store, - }), - }), - ); - - await request(app).get("/").expect(200); - - expect(store.decrementWasCalled).toEqual(false); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when requests fail, `skipFailedRequests` is set to true and a custom `requestWasSuccessful` method used that returns a promise (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipFailedRequests: true, - requestWasSuccessful: async () => false, - store, - }), - }), - ); - - await request(app).get("/").expect(200); - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - // FIXME: This test times out _sometimes_ on MacOS and Windows, so it is disabled for now - /* - ;(platform === 'darwin' ? it.skip : it).each([ - ['modern', new MockStore()], - ['legacy', new MockLegacyStore()], - ['compat', new MockBackwardCompatibleStore()], - ])( - 'should decrement hits when response closes and `skipFailedRequests` is set to true (%s store)', - async (name, store) => { - vi.useRealTimers() - vi.setTimeout(60_000) - - const app = createAdaptorServer(createServer( - rateLimiter({ - skipFailedRequests: true, - store, - }), - ) - - let _resolve: () => void - const connectionClosed = new Promise((resolve) => { - _resolve = resolve - }) - - app.get('/hang-server', (_request, response) => { - response.on('close', _resolve) - }) - - const hangRequest = request(app).get('/hang-server').timeout(10) - - await expect(hangRequest).rejects.toThrow() - await connectionClosed - - expect(store.decrementWasCalled).toEqual(true) - }, - ) - */ - - it.each([["modern", new MockStore()]])( - "should decrement hits when response emits an error and `skipFailedRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - skipFailedRequests: true, - store, - }), - }), - ); - - await request(app).get("/crash"); - - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should decrement hits when rate limit is reached and `skipFailedRequests` is set to true (%s store)", - async (name, store) => { - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 2, - store, - skipFailedRequests: true, - }), - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(200); - await request(app).get("/").expect(429); - - expect(store.decrementWasCalled).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should forward errors in the handler using `next()` (%s store)", - async (name, store) => { - let errorCaught = false; - - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 1, - store, - handler() { - const exception = new Error("420: Enhance your calm"); - throw exception; - }, - }), - }).onError((error, c) => { - errorCaught = true; - return c.text(error.message, 500); - }), - ); - - await request(app).get("/").expect(200); - await request(app).get("/").expect(500); - - expect(errorCaught).toEqual(true); - }, - ); - - it.each([["modern", new MockStore()]])( - "should forward errors in `skip()` using `next()` (%s store)", - async (name, store) => { - let errorCaught = false; - - const app = createAdaptorServer( - createServer({ - middleware: rateLimiter({ - limit: 1, - store, - skip() { - const exception = new Error("420: Enhance your calm"); - throw exception; - }, - }), - }).onError((error, c) => { - errorCaught = true; - return c.text(error.message, 500); - }), - ); - - await request(app).get("/").expect(500); - - expect(errorCaught).toEqual(true); - }, - ); - - it.skip("should pass the number of hits and the limit to the next request handler in the `request.rateLimiter` property", async () => { - let savedRequestObject: RateLimitInfo | undefined; - - const app = createAdaptorServer( - createServer<{ Variables: { rateLimit: RateLimitInfo } }>({ - middleware: [ - async (c, next) => { - savedRequestObject = c.get("rateLimit"); - await next(); - }, - rateLimiter(), - ], - }), - ); - - await request(app).get("/").expect(200); - expect(savedRequestObject).toMatchObject({ - limit: 5, - used: 1, - remaining: 4, - resetTime: expect.any(Date), - }); - - // Make sure the hidden proerty is also set. - expect(savedRequestObject?.current).toBe(1); - - savedRequestObject = undefined; - await request(app).get("/").expect(200); - expect(savedRequestObject).toMatchObject({ - limit: 5, - used: 2, - remaining: 3, - resetTime: expect.any(Date), - }); - expect(savedRequestObject?.current).toBe(2); - }); - - it.skip("should pass the number of hits and the limit to the next request handler with a custom property", async () => { - let savedRequestObject: RateLimitInfo | undefined; - - const app = createAdaptorServer( - createServer<{ Variables: { rateLimit: RateLimitInfo } }>({ - middleware: [ - async (c, next) => { - savedRequestObject = c.get("rateLimit"); - await next(); - }, - rateLimiter({ - requestPropertyName: "rateLimitInfo", - }), - ], - }), - ); - - await request(app).get("/").expect(200); - expect(savedRequestObject).toMatchObject({ - limit: 5, - used: 1, - remaining: 4, - resetTime: expect.any(Date), - }); - expect(savedRequestObject?.current).toBe(1); - - savedRequestObject = undefined; - await request(app).get("/").expect(200); - expect(savedRequestObject).toMatchObject({ - limit: 5, - used: 2, - remaining: 3, - resetTime: expect.any(Date), - }); - expect(savedRequestObject?.current).toBe(2); - }); - - it.skip("should handle two rate-limiters with different `requestPropertyNames` operating independently", async () => { - const keyLimiter = rateLimiter({ - limit: 2, - requestPropertyName: "rateLimitKey", - keyGenerator: (c) => c.req.query("key") ?? "", - handler(c) { - // @ts-expect-error - c.status(420); - return c.text("Enhance your calm"); - }, - }); - const globalLimiter = rateLimiter({ - limit: 5, - requestPropertyName: "rateLimitGlobal", - keyGenerator: () => "global", - handler(c) { - c.status(429); - return c.text("Too many requests"); - }, - }); - - let savedRequestObject: RateLimitInfo; - - const app = createAdaptorServer( - createServer<{ - Variables: { - rateLimit: RateLimitInfo; - rateLimitKey: RateLimitInfo; - rateLimitGlobal: RateLimitInfo; - }; - }>({ - middleware: [ - async (c, next) => { - savedRequestObject = c.get("rateLimit"); - await next(); - }, - keyLimiter, - globalLimiter, - ], - }), - ); - - await request(app).get("/").query({ key: 1 }).expect(200); - expect(savedRequestObject).toBeTruthy(); - expect(savedRequestObject.rateLimiter).toBeUndefined(); - - expect(savedRequestObject.rateLimitKey).toBeTruthy(); - expect(savedRequestObject.rateLimitKey.limit).toEqual(2); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(1); - - expect(savedRequestObject.rateLimitGlobal).toBeTruthy(); - expect(savedRequestObject.rateLimitGlobal.limit).toEqual(5); - expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(4); - - savedRequestObject = undefined; - await request(app).get("/").query({ key: 2 }).expect(200); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(1); - expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(3); - - savedRequestObject = undefined; - await request(app).get("/").query({ key: 1 }).expect(200); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); - expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(2); - - savedRequestObject = undefined; - await request(app).get("/").query({ key: 2 }).expect(200); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); - expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(1); - - savedRequestObject = undefined; - await request(app) - .get("/") - .query({ key: 1 }) - .expect(420, "Enhance your calm"); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); - - savedRequestObject = undefined; - await request(app).get("/").query({ key: 3 }).expect(200); - await request(app) - .get("/") - .query({ key: 3 }) - .expect(429, "Too many requests"); - expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); - expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(0); - }); -}); +// import { platform } from 'node:process' +import { createAdaptorServer } from "@hono/node-server"; +import { agent as request } from "supertest"; +import { rateLimiter } from ".."; +import type { + ClientRateLimitInfo, + ConfigType, + RateLimitInfo, + Store, +} from "../../types"; +import { createServer } from "./helpers"; + +describe.skip("middleware test", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + class MockStore implements Store { + initWasCalled = false; + incrementWasCalled = false; + decrementWasCalled = false; + resetKeyWasCalled = false; + getWasCalled = false; + resetAllWasCalled = false; + + counter = 0; + + init(_options: ConfigType): void { + this.initWasCalled = true; + } + + async get(_key: string): Promise { + this.getWasCalled = true; + + return { totalHits: this.counter, resetTime: undefined }; + } + + async increment(_key: string): Promise { + this.counter += 1; + this.incrementWasCalled = true; + + return { totalHits: this.counter, resetTime: undefined }; + } + + async decrement(_key: string): Promise { + this.counter -= 1; + this.decrementWasCalled = true; + } + + async resetKey(_key: string): Promise { + this.resetKeyWasCalled = true; + } + + async resetAll(): Promise { + this.resetAllWasCalled = true; + } + } + + it("should not modify the options object passed", () => { + const options = {}; + rateLimiter(options); + expect(options).toStrictEqual({}); + }); + + it("should call `init` even if no requests have come in", async () => { + const store = new MockStore(); + rateLimiter({ + store, + }); + + expect(store.initWasCalled).toEqual(true); + }); + + it("should let the first request through", async () => { + const app = createAdaptorServer( + createServer({ middleware: rateLimiter({ limit: 1 }) }), + ); + + await request(app).get("/").expect(200).expect("Hi there!"); + }); + + it("should refuse additional connections once IP has reached the max", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + }); + + it("should (eventually) accept new connections from a blocked IP", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + windowMs: 50, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + vi.advanceTimersByTime(60); + await request(app).get("/").expect(200); + }); + + it("should work repeatedly", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + windowMs: 50, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + vi.advanceTimersByTime(60); + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + vi.advanceTimersByTime(60); + await request(app).get("/").expect(200); + }); + + it("should block all requests if max is set to 0", async () => { + const app = createAdaptorServer( + createServer({ middleware: rateLimiter({ limit: 0 }) }), + ); + + await request(app).get("/").expect(429); + }); + + it("should show the provided message instead of the default message when max connections are reached", async () => { + const message = "Enhance your calm"; + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + windowMs: 1000, + limit: 2, + message, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429).expect(message); + }); + + it("should allow the error status code to be customized", async () => { + const statusCode = 420; + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 1, + // @ts-expect-error + statusCode, + }), + }), + ); + await request(app).get("/").expect(200); + await request(app).get("/").expect(statusCode); + }); + + it.skip("should allow responding with a JSON message", async () => { + const message = { + error: { + code: "too-many-requests", + message: "Too many requests were attempted in a short span of time.", + }, + }; + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + message, + limit: 1, + }), + }), + ); + + await request(app).get("/").expect(200, "Hi there!"); + await request(app).get("/").expect(429, message); + }); + + it("should allow message to be a function", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + message: () => "Too many requests.", + limit: 1, + }), + }), + ); + + await request(app).get("/").expect(200, "Hi there!"); + await request(app).get("/").expect(429, "Too many requests."); + }); + + it("should allow message to be a function that returns a promise", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + message: async () => "Too many requests.", + limit: 1, + }), + }), + ); + + await request(app).get("/").expect(200, "Hi there!"); + await request(app).get("/").expect(429, "Too many requests."); + }); + + it("should use a custom handler when specified", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 1, + handler(c) { + // @ts-expect-error + c.status(420); + return c.text("Enhance your calm"); + }, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(420, "Enhance your calm"); + }); + + it("should allow custom key generators", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + keyGenerator: (c) => c.req.query("key") ?? "", + }), + }), + ); + + await request(app).get("/").query({ key: 1 }).expect(200); + await request(app).get("/").query({ key: 1 }).expect(200); + + await request(app).get("/").query({ key: 2 }).expect(200); + + await request(app).get("/").query({ key: 1 }).expect(429); + + await request(app).get("/").query({ key: 2 }).expect(200); + await request(app).get("/").query({ key: 2 }).expect(429); + }); + + it("should allow custom skip function", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + skip: () => true, + }), + }), + ); + + await request(app).get("/").query({ key: 1 }).expect(200); + await request(app).get("/").query({ key: 1 }).expect(200); + + await request(app).get("/").query({ key: 1 }).expect(200); + }); + + it("should allow custom skip function that returns a promise", async () => { + const limiter = rateLimiter({ + limit: 2, + skip: async () => true, + }); + + const app = createAdaptorServer(createServer({ middleware: limiter })); + await request(app).get("/").query({ key: 1 }).expect(200); + await request(app).get("/").query({ key: 1 }).expect(200); + + await request(app).get("/").query({ key: 1 }).expect(200); + }); + + it("should allow max to be a function", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: () => 2, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + }); + + it("should allow max to be a function that returns a promise", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: async () => 2, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + }); + + it("should calculate the remaining hits", async () => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: async () => 2, + }), + }), + ); + + await request(app) + .get("/") + .expect(200) + .expect("x-ratelimit-limit", "2") + .expect("x-ratelimit-remaining", "1") + .expect((response) => { + if ("retry-after" in response.headers) { + throw new Error( + `Expected no retry-after header, got ${ + response.headers["retry-after"] as string + }`, + ); + } + }) + .expect(200, "Hi there!"); + }); + + it.each([["modern", new MockStore()]])( + "should call `increment` on the store (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + store, + }), + }), + ); + await request(app).get("/"); + + expect(store.incrementWasCalled).toEqual(true); + }, + ); + + it.skip.each([["modern", new MockStore()]])( + "should call `resetKey` on the store (%s store)", + async (name, store) => { + const limiter = rateLimiter({ + store, + }); + + limiter.resetKey("key"); + + expect(store.resetKeyWasCalled).toEqual(true); + }, + ); + + it.skip.each([["modern", new MockStore()]])( + "should call `get` on the store (%s store)", + async (name, store) => { + const limiter = rateLimiter({ + store, + }); + + const response = await limiter.getKey("key"); + + expect(store.getWasCalled).toEqual(true); + expect(typeof response?.totalHits).toBe("number"); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when requests succeed and `skipSuccessfulRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + store, + }), + }), + ); + + await request(app).get("/").expect(200); + + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should not decrement hits when requests fail and `skipSuccessfulRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + store, + }), + }), + ); + + await request(app).get("/error").expect(400); + + expect(store.decrementWasCalled).toEqual(false); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when requests succeed, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + requestWasSuccessful: (c) => c.res.status === 200, + store, + }), + }), + ); + + await request(app).get("/").expect(200); + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should not decrement hits when requests fail, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + requestWasSuccessful(c) { + return c.res.status === 200; + }, + store, + }), + }), + ); + + await request(app).get("/error").expect(400); + + expect(store.decrementWasCalled).toEqual(false); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when requests succeed, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + requestWasSuccessful: (c) => c.req.query("success") === "1", + store, + }), + }), + ); + + await request(app).get("/?success=1"); + + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should not decrement hits when requests fail, `skipSuccessfulRequests` is set to true and a custom `requestWasSuccessful` method used (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipSuccessfulRequests: true, + requestWasSuccessful: (c) => c.req.query("success") === "1", + store, + }), + }), + ); + + await request(app).get("/?success=0"); + + expect(store.decrementWasCalled).toEqual(false); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when requests fail and `skipFailedRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipFailedRequests: true, + store, + }), + }), + ); + + await request(app).get("/error").expect(400); + + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should not decrement hits when requests succeed and `skipFailedRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipFailedRequests: true, + store, + }), + }), + ); + + await request(app).get("/").expect(200); + + expect(store.decrementWasCalled).toEqual(false); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when requests fail, `skipFailedRequests` is set to true and a custom `requestWasSuccessful` method used that returns a promise (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipFailedRequests: true, + requestWasSuccessful: async () => false, + store, + }), + }), + ); + + await request(app).get("/").expect(200); + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + // FIXME: This test times out _sometimes_ on MacOS and Windows, so it is disabled for now + /* + ;(platform === 'darwin' ? it.skip : it).each([ + ['modern', new MockStore()], + ['legacy', new MockLegacyStore()], + ['compat', new MockBackwardCompatibleStore()], + ])( + 'should decrement hits when response closes and `skipFailedRequests` is set to true (%s store)', + async (name, store) => { + vi.useRealTimers() + vi.setTimeout(60_000) + + const app = createAdaptorServer(createServer( + rateLimiter({ + skipFailedRequests: true, + store, + }), + ) + + let _resolve: () => void + const connectionClosed = new Promise((resolve) => { + _resolve = resolve + }) + + app.get('/hang-server', (_request, response) => { + response.on('close', _resolve) + }) + + const hangRequest = request(app).get('/hang-server').timeout(10) + + await expect(hangRequest).rejects.toThrow() + await connectionClosed + + expect(store.decrementWasCalled).toEqual(true) + }, + ) + */ + + it.each([["modern", new MockStore()]])( + "should decrement hits when response emits an error and `skipFailedRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + skipFailedRequests: true, + store, + }), + }), + ); + + await request(app).get("/crash"); + + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should decrement hits when rate limit is reached and `skipFailedRequests` is set to true (%s store)", + async (name, store) => { + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 2, + store, + skipFailedRequests: true, + }), + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(200); + await request(app).get("/").expect(429); + + expect(store.decrementWasCalled).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should forward errors in the handler using `next()` (%s store)", + async (name, store) => { + let errorCaught = false; + + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 1, + store, + handler() { + const exception = new Error("420: Enhance your calm"); + throw exception; + }, + }), + }).onError((error, c) => { + errorCaught = true; + return c.text(error.message, 500); + }), + ); + + await request(app).get("/").expect(200); + await request(app).get("/").expect(500); + + expect(errorCaught).toEqual(true); + }, + ); + + it.each([["modern", new MockStore()]])( + "should forward errors in `skip()` using `next()` (%s store)", + async (name, store) => { + let errorCaught = false; + + const app = createAdaptorServer( + createServer({ + middleware: rateLimiter({ + limit: 1, + store, + skip() { + const exception = new Error("420: Enhance your calm"); + throw exception; + }, + }), + }).onError((error, c) => { + errorCaught = true; + return c.text(error.message, 500); + }), + ); + + await request(app).get("/").expect(500); + + expect(errorCaught).toEqual(true); + }, + ); + + it.skip("should pass the number of hits and the limit to the next request handler in the `request.rateLimiter` property", async () => { + let savedRequestObject: RateLimitInfo | undefined; + + const app = createAdaptorServer( + createServer<{ Variables: { rateLimit: RateLimitInfo } }>({ + middleware: [ + async (c, next) => { + savedRequestObject = c.get("rateLimit"); + await next(); + }, + rateLimiter(), + ], + }), + ); + + await request(app).get("/").expect(200); + expect(savedRequestObject).toMatchObject({ + limit: 5, + used: 1, + remaining: 4, + resetTime: expect.any(Date), + }); + + // Make sure the hidden proerty is also set. + expect(savedRequestObject?.current).toBe(1); + + savedRequestObject = undefined; + await request(app).get("/").expect(200); + expect(savedRequestObject).toMatchObject({ + limit: 5, + used: 2, + remaining: 3, + resetTime: expect.any(Date), + }); + expect(savedRequestObject?.current).toBe(2); + }); + + it.skip("should pass the number of hits and the limit to the next request handler with a custom property", async () => { + let savedRequestObject: RateLimitInfo | undefined; + + const app = createAdaptorServer( + createServer<{ Variables: { rateLimit: RateLimitInfo } }>({ + middleware: [ + async (c, next) => { + savedRequestObject = c.get("rateLimit"); + await next(); + }, + rateLimiter({ + requestPropertyName: "rateLimitInfo", + }), + ], + }), + ); + + await request(app).get("/").expect(200); + expect(savedRequestObject).toMatchObject({ + limit: 5, + used: 1, + remaining: 4, + resetTime: expect.any(Date), + }); + expect(savedRequestObject?.current).toBe(1); + + savedRequestObject = undefined; + await request(app).get("/").expect(200); + expect(savedRequestObject).toMatchObject({ + limit: 5, + used: 2, + remaining: 3, + resetTime: expect.any(Date), + }); + expect(savedRequestObject?.current).toBe(2); + }); + + it.skip("should handle two rate-limiters with different `requestPropertyNames` operating independently", async () => { + const keyLimiter = rateLimiter({ + limit: 2, + requestPropertyName: "rateLimitKey", + keyGenerator: (c) => c.req.query("key") ?? "", + handler(c) { + // @ts-expect-error + c.status(420); + return c.text("Enhance your calm"); + }, + }); + const globalLimiter = rateLimiter({ + limit: 5, + requestPropertyName: "rateLimitGlobal", + keyGenerator: () => "global", + handler(c) { + c.status(429); + return c.text("Too many requests"); + }, + }); + + let savedRequestObject: RateLimitInfo; + + const app = createAdaptorServer( + createServer<{ + Variables: { + rateLimit: RateLimitInfo; + rateLimitKey: RateLimitInfo; + rateLimitGlobal: RateLimitInfo; + }; + }>({ + middleware: [ + async (c, next) => { + savedRequestObject = c.get("rateLimit"); + await next(); + }, + keyLimiter, + globalLimiter, + ], + }), + ); + + await request(app).get("/").query({ key: 1 }).expect(200); + expect(savedRequestObject).toBeTruthy(); + expect(savedRequestObject.rateLimiter).toBeUndefined(); + + expect(savedRequestObject.rateLimitKey).toBeTruthy(); + expect(savedRequestObject.rateLimitKey.limit).toEqual(2); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(1); + + expect(savedRequestObject.rateLimitGlobal).toBeTruthy(); + expect(savedRequestObject.rateLimitGlobal.limit).toEqual(5); + expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(4); + + savedRequestObject = undefined; + await request(app).get("/").query({ key: 2 }).expect(200); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(1); + expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(3); + + savedRequestObject = undefined; + await request(app).get("/").query({ key: 1 }).expect(200); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); + expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(2); + + savedRequestObject = undefined; + await request(app).get("/").query({ key: 2 }).expect(200); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); + expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(1); + + savedRequestObject = undefined; + await request(app) + .get("/") + .query({ key: 1 }) + .expect(420, "Enhance your calm"); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); + + savedRequestObject = undefined; + await request(app).get("/").query({ key: 3 }).expect(200); + await request(app) + .get("/") + .query({ key: 3 }) + .expect(429, "Too many requests"); + expect(savedRequestObject.rateLimitKey.remaining).toEqual(0); + expect(savedRequestObject.rateLimitGlobal.remaining).toEqual(0); + }); +}); diff --git a/packages/rate-limiter/src/core/__tests__/options.spec.ts b/packages/core/src/core/__tests__/options.spec.ts similarity index 100% rename from packages/rate-limiter/src/core/__tests__/options.spec.ts rename to packages/core/src/core/__tests__/options.spec.ts diff --git a/packages/rate-limiter/src/core/headers.ts b/packages/core/src/core/headers.ts similarity index 97% rename from packages/rate-limiter/src/core/headers.ts rename to packages/core/src/core/headers.ts index f582128..1d45bde 100644 --- a/packages/rate-limiter/src/core/headers.ts +++ b/packages/core/src/core/headers.ts @@ -1,96 +1,96 @@ -import type { Context } from "hono"; -import type { RateLimitInfo } from "../types"; - -/** - * Returns the number of seconds left for the window to reset. Uses `windowMs` - * in case the store doesn't return a `resetTime`. - * - * @param resetTime {Date | undefined} - The timestamp at which the store window resets. - * @param windowMs {number | undefined} - The window length. - */ -const getResetSeconds = ( - resetTime?: Date, - windowMs?: number, -): number | undefined => { - let resetSeconds: number | undefined; - if (resetTime) { - const deltaSeconds = Math.ceil((resetTime.getTime() - Date.now()) / 1000); - resetSeconds = Math.max(0, deltaSeconds); - } else if (windowMs) { - // This isn't really correct, but the field is required by the spec in `draft-7`, - // so this is the best we can do. The validator should have already logged a - // warning by this point. - resetSeconds = Math.ceil(windowMs / 1000); - } - - return resetSeconds; -}; - -/** - * Sets `RateLimit-*`` headers based on the sixth draft of the IETF specification. - * See https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers-06. - * - * @param context {Context} - The hono context object to set headers on. - * @param info {RateLimitInfo} - The rate limit info, used to set the headers. - * @param windowMs {number} - The window length. - */ -export const setDraft6Headers = ( - context: Context, - info: RateLimitInfo, - windowMs: number, -): void => { - if (context.finalized) return; - - const windowSeconds = Math.ceil(windowMs / 1000); - const resetSeconds = getResetSeconds(info.resetTime); - - context.header("RateLimit-Policy", `${info.limit};w=${windowSeconds}`); - context.header("RateLimit-Limit", info.limit.toString()); - context.header("RateLimit-Remaining", info.remaining.toString()); - - // Set this header only if the store returns a `resetTime`. - if (resetSeconds) context.header("RateLimit-Reset", resetSeconds.toString()); -}; - -/** - * Sets `RateLimit` & `RateLimit-Policy` headers based on the seventh draft of the spec. - * See https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers-07. - * - * @param context {Context} - The hono context object to set headers on. - * @param info {RateLimitInfo} - The rate limit info, used to set the headers. - * @param windowMs {number} - The window length. - */ -export const setDraft7Headers = ( - context: Context, - info: RateLimitInfo, - windowMs: number, -): void => { - if (context.finalized) return; - - const windowSeconds = Math.ceil(windowMs / 1000); - const resetSeconds = getResetSeconds(info.resetTime, windowMs); - - context.header("RateLimit-Policy", `${info.limit};w=${windowSeconds}`); - context.header( - "RateLimit", - `limit=${info.limit}, remaining=${info.remaining}, reset=${resetSeconds}`, - ); -}; - -/** - * Sets the `Retry-After` header. - * - * @param context {Context} - The hono context object to set headers on. - * @param info {RateLimitInfo} - The rate limit info, used to set the headers. - * @param windowMs {number} - The window length. - */ -export const setRetryAfterHeader = ( - context: Context, - info: RateLimitInfo, - windowMs: number, -): void => { - if (context.finalized) return; - - const resetSeconds = getResetSeconds(info.resetTime, windowMs); - context.header("Retry-After", resetSeconds?.toString()); -}; +import type { Context } from "hono"; +import type { RateLimitInfo } from "../types"; + +/** + * Returns the number of seconds left for the window to reset. Uses `windowMs` + * in case the store doesn't return a `resetTime`. + * + * @param resetTime {Date | undefined} - The timestamp at which the store window resets. + * @param windowMs {number | undefined} - The window length. + */ +const getResetSeconds = ( + resetTime?: Date, + windowMs?: number, +): number | undefined => { + let resetSeconds: number | undefined; + if (resetTime) { + const deltaSeconds = Math.ceil((resetTime.getTime() - Date.now()) / 1000); + resetSeconds = Math.max(0, deltaSeconds); + } else if (windowMs) { + // This isn't really correct, but the field is required by the spec in `draft-7`, + // so this is the best we can do. The validator should have already logged a + // warning by this point. + resetSeconds = Math.ceil(windowMs / 1000); + } + + return resetSeconds; +}; + +/** + * Sets `RateLimit-*`` headers based on the sixth draft of the IETF specification. + * See https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers-06. + * + * @param context {Context} - The hono context object to set headers on. + * @param info {RateLimitInfo} - The rate limit info, used to set the headers. + * @param windowMs {number} - The window length. + */ +export const setDraft6Headers = ( + context: Context, + info: RateLimitInfo, + windowMs: number, +): void => { + if (context.finalized) return; + + const windowSeconds = Math.ceil(windowMs / 1000); + const resetSeconds = getResetSeconds(info.resetTime); + + context.header("RateLimit-Policy", `${info.limit};w=${windowSeconds}`); + context.header("RateLimit-Limit", info.limit.toString()); + context.header("RateLimit-Remaining", info.remaining.toString()); + + // Set this header only if the store returns a `resetTime`. + if (resetSeconds) context.header("RateLimit-Reset", resetSeconds.toString()); +}; + +/** + * Sets `RateLimit` & `RateLimit-Policy` headers based on the seventh draft of the spec. + * See https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers-07. + * + * @param context {Context} - The hono context object to set headers on. + * @param info {RateLimitInfo} - The rate limit info, used to set the headers. + * @param windowMs {number} - The window length. + */ +export const setDraft7Headers = ( + context: Context, + info: RateLimitInfo, + windowMs: number, +): void => { + if (context.finalized) return; + + const windowSeconds = Math.ceil(windowMs / 1000); + const resetSeconds = getResetSeconds(info.resetTime, windowMs); + + context.header("RateLimit-Policy", `${info.limit};w=${windowSeconds}`); + context.header( + "RateLimit", + `limit=${info.limit}, remaining=${info.remaining}, reset=${resetSeconds}`, + ); +}; + +/** + * Sets the `Retry-After` header. + * + * @param context {Context} - The hono context object to set headers on. + * @param info {RateLimitInfo} - The rate limit info, used to set the headers. + * @param windowMs {number} - The window length. + */ +export const setRetryAfterHeader = ( + context: Context, + info: RateLimitInfo, + windowMs: number, +): void => { + if (context.finalized) return; + + const resetSeconds = getResetSeconds(info.resetTime, windowMs); + context.header("Retry-After", resetSeconds?.toString()); +}; diff --git a/packages/rate-limiter/src/core/index.ts b/packages/core/src/core/index.ts similarity index 100% rename from packages/rate-limiter/src/core/index.ts rename to packages/core/src/core/index.ts diff --git a/packages/rate-limiter/src/core/validations.ts b/packages/core/src/core/validations.ts similarity index 96% rename from packages/rate-limiter/src/core/validations.ts rename to packages/core/src/core/validations.ts index abaf0f1..39343b0 100644 --- a/packages/rate-limiter/src/core/validations.ts +++ b/packages/core/src/core/validations.ts @@ -1,4 +1,4 @@ -import type { Store } from "../types"; - -export const isValidStore = (value: Store): value is Store => - !!value?.increment; +import type { Store } from "../types"; + +export const isValidStore = (value: Store): value is Store => + !!value?.increment; diff --git a/packages/rate-limiter/src/index.ts b/packages/core/src/index.ts similarity index 100% rename from packages/rate-limiter/src/index.ts rename to packages/core/src/index.ts diff --git a/packages/rate-limiter/src/memcache/__tests__/store.spec.ts b/packages/core/src/memcache/__tests__/store.spec.ts similarity index 97% rename from packages/rate-limiter/src/memcache/__tests__/store.spec.ts rename to packages/core/src/memcache/__tests__/store.spec.ts index 44aa78f..aed240d 100644 --- a/packages/rate-limiter/src/memcache/__tests__/store.spec.ts +++ b/packages/core/src/memcache/__tests__/store.spec.ts @@ -1,181 +1,181 @@ -import MemoryStore from ".."; -import type { ConfigType } from "../../types"; - -const minute = 60 * 1000; - -describe("memory store test", () => { - beforeEach(() => { - vi.useFakeTimers(); - vi.spyOn(global, "clearInterval"); - }); - afterEach(() => { - vi.useRealTimers(); - vi.restoreAllMocks(); - }); - - it("returns the current hit count and reset time for a key", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const key = "test-store"; - - await store.increment(key); - - const response = await store.get(key); - expect(response).toMatchObject({ - totalHits: 1, - resetTime: expect.any(Date), - }); - }); - - it("sets the value to 1 on first call to `increment`", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const key = "test-store"; - - const { totalHits } = await store.increment(key); - expect(totalHits).toEqual(1); - }); - - it("increments the key for the store when `increment` is called", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const key = "test-store"; - - await store.increment(key); - - const { totalHits } = await store.increment(key); - expect(totalHits).toEqual(2); - }); - - it("decrements the key for the store when `decrement` is called", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const key = "test-store"; - - await store.increment(key); - await store.increment(key); - await store.decrement(key); - - const { totalHits } = await store.increment(key); - expect(totalHits).toEqual(2); - }); - - it("resets the count for a key in the store when `resetKey` is called", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const key = "test-store"; - - await store.increment(key); - await store.resetKey(key); - - const { totalHits } = await store.increment(key); - expect(totalHits).toEqual(1); - }); - - it("resets the count for all keys in the store when `resetAll` is called", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - const keyOne = "test-store-one"; - const keyTwo = "test-store-two"; - - await store.increment(keyOne); - await store.increment(keyTwo); - await store.resetAll(); - - const { totalHits: totalHitsOne } = await store.increment(keyOne); - const { totalHits: totalHitsTwo } = await store.increment(keyTwo); - expect(totalHitsOne).toEqual(1); - expect(totalHitsTwo).toEqual(1); - }); - - it("clears the timer when `shutdown` is called", async () => { - const store = new MemoryStore(); - store.init({ windowMs: minute } as ConfigType); - expect(store.interval).toBeDefined(); - store.shutdown(); - expect(clearInterval).toHaveBeenCalledWith(store.interval); - }); - - it("resets the count for all the keys in the store when the timeout is reached", async () => { - const store = new MemoryStore(); - store.init({ windowMs: 50 } as ConfigType); - const keyOne = "test-store-one"; - const keyTwo = "test-store-two"; - - await store.increment(keyOne); - await store.increment(keyTwo); - - vi.advanceTimersByTime(60); - const { totalHits: totalHitsOne } = await store.increment(keyOne); - const { totalHits: totalHitsTwo } = await store.increment(keyTwo); - expect(totalHitsOne).toEqual(1); - expect(totalHitsTwo).toEqual(1); - }); - - it("can run in electron where setInterval does not return a Timeout object with an unset function", async () => { - const originalSetInterval = setInterval; - let timeoutId = 1; - let realTimeoutId: NodeJS.Timer; - // @ts-expect-error We want to not return a `Timer` instance for testing - vi.spyOn(global, "setTimeout").mockImplementation((callback, timeout) => { - realTimeoutId = originalSetInterval(callback, timeout); - return timeoutId++; - }); - - const store = new MemoryStore(); - store.init({ windowMs: -1 } as ConfigType); - const key = "test-store"; - - try { - const { totalHits } = await store.increment(key); - expect(totalHits).toEqual(1); - } finally { - // @ts-expect-error `realTimeoutId` is already set in the `spyOn` call - clearTimeout(realTimeoutId); - } - }); - - it("should move clients from previous to current", async () => { - const store = new MemoryStore(); - store.init({ windowMs: 100 } as ConfigType); - - await store.increment("key1"); - // Key1 is now in current - expect(store.current.has("key1")).toBe(true); - expect(store.previous.has("key1")).toBe(false); - - vi.advanceTimersByTime(100); - // Key1 is now in previous, current is empty - expect(store.current.has("key1")).toBe(false); - expect(store.previous.has("key1")).toBe(true); - - await store.increment("key1"); - // Should move key from previous to current - expect(store.current.has("key1")).toBe(true); - expect(store.previous.has("key1")).toBe(false); - }); - - // Covers the same bug as above, but in a more robust way that doesn't touch any internal structures - it("does not allow a Client object to be assigned to two keys", async () => { - const store = new MemoryStore(); - store.init({ windowMs: 100 } as ConfigType); - await store.increment("key1"); // Key1 is now in current - - vi.advanceTimersByTime(100); // Key1 is now in previous. Target pool size is 1, but it's empty. - await store.increment("key1"); // Key1 is now in current again. If it's also in previous, that's a bug! - await store.increment("key2"); // Need 1 new client to keep the pool size target at 1 - - vi.advanceTimersByTime(100); // Key1 and key2 are now in previous. Target pool size is 1, but it should be empty. - await store.increment("key1"); // Move it from previous to current - await store.increment("key1"); - let returnValue1 = await store.increment("key1"); - expect(returnValue1.totalHits).toBe(3); - - const returnValue3 = await store.increment("key3"); // Should create a new Client instance because the pool should be empty. In the bad case, it instead resets the same object to 1 - expect(returnValue1).not.toBe(returnValue3); // Should be separate objects - expect(returnValue3.totalHits).toBe(1); - - returnValue1 = await store.increment("key1"); - expect(returnValue1.totalHits).toBe(4); // Should be 4, will be 2 if there's a reuse bug - }); -}); +import MemoryStore from ".."; +import type { ConfigType } from "../../types"; + +const minute = 60 * 1000; + +describe("memory store test", () => { + beforeEach(() => { + vi.useFakeTimers(); + vi.spyOn(global, "clearInterval"); + }); + afterEach(() => { + vi.useRealTimers(); + vi.restoreAllMocks(); + }); + + it("returns the current hit count and reset time for a key", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const key = "test-store"; + + await store.increment(key); + + const response = await store.get(key); + expect(response).toMatchObject({ + totalHits: 1, + resetTime: expect.any(Date), + }); + }); + + it("sets the value to 1 on first call to `increment`", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const key = "test-store"; + + const { totalHits } = await store.increment(key); + expect(totalHits).toEqual(1); + }); + + it("increments the key for the store when `increment` is called", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const key = "test-store"; + + await store.increment(key); + + const { totalHits } = await store.increment(key); + expect(totalHits).toEqual(2); + }); + + it("decrements the key for the store when `decrement` is called", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const key = "test-store"; + + await store.increment(key); + await store.increment(key); + await store.decrement(key); + + const { totalHits } = await store.increment(key); + expect(totalHits).toEqual(2); + }); + + it("resets the count for a key in the store when `resetKey` is called", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const key = "test-store"; + + await store.increment(key); + await store.resetKey(key); + + const { totalHits } = await store.increment(key); + expect(totalHits).toEqual(1); + }); + + it("resets the count for all keys in the store when `resetAll` is called", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + const keyOne = "test-store-one"; + const keyTwo = "test-store-two"; + + await store.increment(keyOne); + await store.increment(keyTwo); + await store.resetAll(); + + const { totalHits: totalHitsOne } = await store.increment(keyOne); + const { totalHits: totalHitsTwo } = await store.increment(keyTwo); + expect(totalHitsOne).toEqual(1); + expect(totalHitsTwo).toEqual(1); + }); + + it("clears the timer when `shutdown` is called", async () => { + const store = new MemoryStore(); + store.init({ windowMs: minute } as ConfigType); + expect(store.interval).toBeDefined(); + store.shutdown(); + expect(clearInterval).toHaveBeenCalledWith(store.interval); + }); + + it("resets the count for all the keys in the store when the timeout is reached", async () => { + const store = new MemoryStore(); + store.init({ windowMs: 50 } as ConfigType); + const keyOne = "test-store-one"; + const keyTwo = "test-store-two"; + + await store.increment(keyOne); + await store.increment(keyTwo); + + vi.advanceTimersByTime(60); + const { totalHits: totalHitsOne } = await store.increment(keyOne); + const { totalHits: totalHitsTwo } = await store.increment(keyTwo); + expect(totalHitsOne).toEqual(1); + expect(totalHitsTwo).toEqual(1); + }); + + it("can run in electron where setInterval does not return a Timeout object with an unset function", async () => { + const originalSetInterval = setInterval; + let timeoutId = 1; + let realTimeoutId: NodeJS.Timer; + // @ts-expect-error We want to not return a `Timer` instance for testing + vi.spyOn(global, "setTimeout").mockImplementation((callback, timeout) => { + realTimeoutId = originalSetInterval(callback, timeout); + return timeoutId++; + }); + + const store = new MemoryStore(); + store.init({ windowMs: -1 } as ConfigType); + const key = "test-store"; + + try { + const { totalHits } = await store.increment(key); + expect(totalHits).toEqual(1); + } finally { + // @ts-expect-error `realTimeoutId` is already set in the `spyOn` call + clearTimeout(realTimeoutId); + } + }); + + it("should move clients from previous to current", async () => { + const store = new MemoryStore(); + store.init({ windowMs: 100 } as ConfigType); + + await store.increment("key1"); + // Key1 is now in current + expect(store.current.has("key1")).toBe(true); + expect(store.previous.has("key1")).toBe(false); + + vi.advanceTimersByTime(100); + // Key1 is now in previous, current is empty + expect(store.current.has("key1")).toBe(false); + expect(store.previous.has("key1")).toBe(true); + + await store.increment("key1"); + // Should move key from previous to current + expect(store.current.has("key1")).toBe(true); + expect(store.previous.has("key1")).toBe(false); + }); + + // Covers the same bug as above, but in a more robust way that doesn't touch any internal structures + it("does not allow a Client object to be assigned to two keys", async () => { + const store = new MemoryStore(); + store.init({ windowMs: 100 } as ConfigType); + await store.increment("key1"); // Key1 is now in current + + vi.advanceTimersByTime(100); // Key1 is now in previous. Target pool size is 1, but it's empty. + await store.increment("key1"); // Key1 is now in current again. If it's also in previous, that's a bug! + await store.increment("key2"); // Need 1 new client to keep the pool size target at 1 + + vi.advanceTimersByTime(100); // Key1 and key2 are now in previous. Target pool size is 1, but it should be empty. + await store.increment("key1"); // Move it from previous to current + await store.increment("key1"); + let returnValue1 = await store.increment("key1"); + expect(returnValue1.totalHits).toBe(3); + + const returnValue3 = await store.increment("key3"); // Should create a new Client instance because the pool should be empty. In the bad case, it instead resets the same object to 1 + expect(returnValue1).not.toBe(returnValue3); // Should be separate objects + expect(returnValue3.totalHits).toBe(1); + + returnValue1 = await store.increment("key1"); + expect(returnValue1.totalHits).toBe(4); // Should be 4, will be 2 if there's a reuse bug + }); +}); diff --git a/packages/rate-limiter/src/memcache/index.ts b/packages/core/src/memcache/index.ts similarity index 96% rename from packages/rate-limiter/src/memcache/index.ts rename to packages/core/src/memcache/index.ts index 9987240..5aecd71 100644 --- a/packages/rate-limiter/src/memcache/index.ts +++ b/packages/core/src/memcache/index.ts @@ -1,210 +1,210 @@ -import type { ClientRateLimitInfo, ConfigType, Store } from "../types"; - -/** - * The record that stores information about a client - namely, how many times - * they have hit the endpoint, and when their hit count resets. - * - * Similar to `ClientRateLimitInfo`, except `resetTime` is a compulsory field. - */ -type Client = { - totalHits: number; - resetTime: Date; -}; - -/** - * A `Store` that stores the hit count for each client in memory. - * - * @public - */ -export default class MemoryStore implements Store { - /** - * The duration of time before which all hit counts are reset (in milliseconds). - */ - windowMs!: number; - - /** - * These two maps store usage (requests) and reset time by key (for example, IP - * addresses or API keys). - * - * They are split into two to avoid having to iterate through the entire set to - * determine which ones need reset. Instead, `Client`s are moved from `previous` - * to `current` as they hit the endpoint. Once `windowMs` has elapsed, all clients - * left in `previous`, i.e., those that have not made any recent requests, are - * known to be expired and can be deleted in bulk. - */ - previous = new Map(); - current = new Map(); - - /** - * A reference to the active timer. - */ - interval?: NodeJS.Timeout; - - /** - * Confirmation that the keys incremented in once instance of MemoryStore - * cannot affect other instances. - */ - localKeys = true; - - /** - * Method that initializes the store. - * - * @param options {ConfigType} - The options used to setup the middleware. - */ - init(options: ConfigType): void { - // Get the duration of a window from the options. - this.windowMs = options.windowMs; - - // Indicates that init was called more than once. - // Could happen if a store was shared between multiple instances. - if (this.interval) clearInterval(this.interval); - - // Reset all clients left in previous every `windowMs`. - this.interval = setInterval(() => { - this.clearExpired(); - }, this.windowMs); - - // Cleaning up the interval will be taken care of by the `shutdown` method. - if (this.interval.unref) this.interval.unref(); - } - - /** - * Method to fetch a client's hit count and reset time. - * - * @param key {string} - The identifier for a client. - * - * @returns {ClientRateLimitInfo | undefined} - The number of hits and reset time for that client. - * - * @public - */ - async get(key: string): Promise { - return this.current.get(key) ?? this.previous.get(key); - } - - /** - * Method to increment a client's hit counter. - * - * @param key {string} - The identifier for a client. - * - * @returns {ClientRateLimitInfo} - The number of hits and reset time for that client. - * - * @public - */ - async increment(key: string): Promise { - const client = this.getClient(key); - - const now = Date.now(); - if (client.resetTime.getTime() <= now) { - this.resetClient(client, now); - } - - client.totalHits++; - return client; - } - - /** - * Method to decrement a client's hit counter. - * - * @param key {string} - The identifier for a client. - * - * @public - */ - async decrement(key: string): Promise { - const client = this.getClient(key); - - if (client.totalHits > 0) client.totalHits--; - } - - /** - * Method to reset a client's hit counter. - * - * @param key {string} - The identifier for a client. - * - * @public - */ - async resetKey(key: string): Promise { - this.current.delete(key); - this.previous.delete(key); - } - - /** - * Method to reset everyone's hit counter. - * - * @public - */ - async resetAll(): Promise { - this.current.clear(); - this.previous.clear(); - } - - /** - * Method to stop the timer (if currently running) and prevent any memory - * leaks. - * - * @public - */ - shutdown(): void { - clearInterval(this.interval); - void this.resetAll(); - } - - /** - * Recycles a client by setting its hit count to zero, and reset time to - * `windowMs` milliseconds from now. - * - * NOT to be confused with `#resetKey()`, which removes a client from both the - * `current` and `previous` maps. - * - * @param client {Client} - The client to recycle. - * @param now {number} - The current time, to which the `windowMs` is added to get the `resetTime` for the client. - * - * @return {Client} - The modified client that was passed in, to allow for chaining. - */ - private resetClient(client: Client, now = Date.now()): Client { - client.totalHits = 0; - client.resetTime.setTime(now + this.windowMs); - - return client; - } - - /** - * Retrieves or creates a client, given a key. Also ensures that the client being - * returned is in the `current` map. - * - * @param key {string} - The key under which the client is (or is to be) stored. - * - * @returns {Client} - The requested client. - */ - private getClient(key: string): Client { - // If we already have a client for that key in the `current` map, return it. - // biome-ignore lint/style/noNonNullAssertion: safe to use cause already checked - if (this.current.has(key)) return this.current.get(key)!; - - let client: Client; - if (this.previous.has(key)) { - // If it's in the `previous` map, take it out - // biome-ignore lint/style/noNonNullAssertion: safe to use cause already checked - client = this.previous.get(key)!; - this.previous.delete(key); - } else { - // Finally, if we don't have an existing entry for this client, create a new one - client = { totalHits: 0, resetTime: new Date() }; - this.resetClient(client); - } - - // Make sure the client is bumped into the `current` map, and return it. - this.current.set(key, client); - return client; - } - - /** - * Move current clients to previous, create a new map for current. - * - * This function is called every `windowMs`. - */ - private clearExpired(): void { - // At this point, all clients in previous are expired - this.previous = this.current; - this.current = new Map(); - } -} +import type { ClientRateLimitInfo, ConfigType, Store } from "../types"; + +/** + * The record that stores information about a client - namely, how many times + * they have hit the endpoint, and when their hit count resets. + * + * Similar to `ClientRateLimitInfo`, except `resetTime` is a compulsory field. + */ +type Client = { + totalHits: number; + resetTime: Date; +}; + +/** + * A `Store` that stores the hit count for each client in memory. + * + * @public + */ +export default class MemoryStore implements Store { + /** + * The duration of time before which all hit counts are reset (in milliseconds). + */ + windowMs!: number; + + /** + * These two maps store usage (requests) and reset time by key (for example, IP + * addresses or API keys). + * + * They are split into two to avoid having to iterate through the entire set to + * determine which ones need reset. Instead, `Client`s are moved from `previous` + * to `current` as they hit the endpoint. Once `windowMs` has elapsed, all clients + * left in `previous`, i.e., those that have not made any recent requests, are + * known to be expired and can be deleted in bulk. + */ + previous = new Map(); + current = new Map(); + + /** + * A reference to the active timer. + */ + interval?: NodeJS.Timeout; + + /** + * Confirmation that the keys incremented in once instance of MemoryStore + * cannot affect other instances. + */ + localKeys = true; + + /** + * Method that initializes the store. + * + * @param options {ConfigType} - The options used to setup the middleware. + */ + init(options: ConfigType): void { + // Get the duration of a window from the options. + this.windowMs = options.windowMs; + + // Indicates that init was called more than once. + // Could happen if a store was shared between multiple instances. + if (this.interval) clearInterval(this.interval); + + // Reset all clients left in previous every `windowMs`. + this.interval = setInterval(() => { + this.clearExpired(); + }, this.windowMs); + + // Cleaning up the interval will be taken care of by the `shutdown` method. + if (this.interval.unref) this.interval.unref(); + } + + /** + * Method to fetch a client's hit count and reset time. + * + * @param key {string} - The identifier for a client. + * + * @returns {ClientRateLimitInfo | undefined} - The number of hits and reset time for that client. + * + * @public + */ + async get(key: string): Promise { + return this.current.get(key) ?? this.previous.get(key); + } + + /** + * Method to increment a client's hit counter. + * + * @param key {string} - The identifier for a client. + * + * @returns {ClientRateLimitInfo} - The number of hits and reset time for that client. + * + * @public + */ + async increment(key: string): Promise { + const client = this.getClient(key); + + const now = Date.now(); + if (client.resetTime.getTime() <= now) { + this.resetClient(client, now); + } + + client.totalHits++; + return client; + } + + /** + * Method to decrement a client's hit counter. + * + * @param key {string} - The identifier for a client. + * + * @public + */ + async decrement(key: string): Promise { + const client = this.getClient(key); + + if (client.totalHits > 0) client.totalHits--; + } + + /** + * Method to reset a client's hit counter. + * + * @param key {string} - The identifier for a client. + * + * @public + */ + async resetKey(key: string): Promise { + this.current.delete(key); + this.previous.delete(key); + } + + /** + * Method to reset everyone's hit counter. + * + * @public + */ + async resetAll(): Promise { + this.current.clear(); + this.previous.clear(); + } + + /** + * Method to stop the timer (if currently running) and prevent any memory + * leaks. + * + * @public + */ + shutdown(): void { + clearInterval(this.interval); + void this.resetAll(); + } + + /** + * Recycles a client by setting its hit count to zero, and reset time to + * `windowMs` milliseconds from now. + * + * NOT to be confused with `#resetKey()`, which removes a client from both the + * `current` and `previous` maps. + * + * @param client {Client} - The client to recycle. + * @param now {number} - The current time, to which the `windowMs` is added to get the `resetTime` for the client. + * + * @return {Client} - The modified client that was passed in, to allow for chaining. + */ + private resetClient(client: Client, now = Date.now()): Client { + client.totalHits = 0; + client.resetTime.setTime(now + this.windowMs); + + return client; + } + + /** + * Retrieves or creates a client, given a key. Also ensures that the client being + * returned is in the `current` map. + * + * @param key {string} - The key under which the client is (or is to be) stored. + * + * @returns {Client} - The requested client. + */ + private getClient(key: string): Client { + // If we already have a client for that key in the `current` map, return it. + // biome-ignore lint/style/noNonNullAssertion: safe to use cause already checked + if (this.current.has(key)) return this.current.get(key)!; + + let client: Client; + if (this.previous.has(key)) { + // If it's in the `previous` map, take it out + // biome-ignore lint/style/noNonNullAssertion: safe to use cause already checked + client = this.previous.get(key)!; + this.previous.delete(key); + } else { + // Finally, if we don't have an existing entry for this client, create a new one + client = { totalHits: 0, resetTime: new Date() }; + this.resetClient(client); + } + + // Make sure the client is bumped into the `current` map, and return it. + this.current.set(key, client); + return client; + } + + /** + * Move current clients to previous, create a new map for current. + * + * This function is called every `windowMs`. + */ + private clearExpired(): void { + // At this point, all clients in previous are expired + this.previous = this.current; + this.current = new Map(); + } +} diff --git a/packages/rate-limiter/src/types/clientRateLimitInfo.ts b/packages/core/src/types/clientRateLimitInfo.ts similarity index 95% rename from packages/rate-limiter/src/types/clientRateLimitInfo.ts rename to packages/core/src/types/clientRateLimitInfo.ts index 6e35397..7912f55 100644 --- a/packages/rate-limiter/src/types/clientRateLimitInfo.ts +++ b/packages/core/src/types/clientRateLimitInfo.ts @@ -1,4 +1,4 @@ -export type ClientRateLimitInfo = { - totalHits: number; - resetTime?: Date; -}; +export type ClientRateLimitInfo = { + totalHits: number; + resetTime?: Date; +}; diff --git a/packages/rate-limiter/src/types/config.ts b/packages/core/src/types/config.ts similarity index 100% rename from packages/rate-limiter/src/types/config.ts rename to packages/core/src/types/config.ts diff --git a/packages/rate-limiter/src/types/index.ts b/packages/core/src/types/index.ts similarity index 97% rename from packages/rate-limiter/src/types/index.ts rename to packages/core/src/types/index.ts index 7e5ce8b..8a5a68e 100644 --- a/packages/rate-limiter/src/types/index.ts +++ b/packages/core/src/types/index.ts @@ -1,5 +1,5 @@ -export type * from "./store"; -export type * from "./promisify"; -export type * from "./config"; -export type * from "./clientRateLimitInfo"; -export type * from "./rateLimitInfo"; +export type * from "./store"; +export type * from "./promisify"; +export type * from "./config"; +export type * from "./clientRateLimitInfo"; +export type * from "./rateLimitInfo"; diff --git a/packages/rate-limiter/src/types/promisify.ts b/packages/core/src/types/promisify.ts similarity index 97% rename from packages/rate-limiter/src/types/promisify.ts rename to packages/core/src/types/promisify.ts index 2ca0734..55b5c32 100644 --- a/packages/rate-limiter/src/types/promisify.ts +++ b/packages/core/src/types/promisify.ts @@ -1 +1 @@ -export type Promisify = T | Promise; +export type Promisify = T | Promise; diff --git a/packages/rate-limiter/src/types/rateLimitExceededEventHandler.ts b/packages/core/src/types/rateLimitExceededEventHandler.ts similarity index 96% rename from packages/rate-limiter/src/types/rateLimitExceededEventHandler.ts rename to packages/core/src/types/rateLimitExceededEventHandler.ts index e5d2436..e435198 100644 --- a/packages/rate-limiter/src/types/rateLimitExceededEventHandler.ts +++ b/packages/core/src/types/rateLimitExceededEventHandler.ts @@ -1 +1 @@ -import { Context } from "hono"; +import { Context } from "hono"; diff --git a/packages/rate-limiter/src/types/rateLimitInfo.ts b/packages/core/src/types/rateLimitInfo.ts similarity index 95% rename from packages/rate-limiter/src/types/rateLimitInfo.ts rename to packages/core/src/types/rateLimitInfo.ts index 2a3bb2d..7e26a5e 100644 --- a/packages/rate-limiter/src/types/rateLimitInfo.ts +++ b/packages/core/src/types/rateLimitInfo.ts @@ -1,10 +1,10 @@ -/** - * The rate limit related information for each client included in the - * Hono context object. - */ -export type RateLimitInfo = { - limit: number; - used: number; - remaining: number; - resetTime: Date | undefined; -}; +/** + * The rate limit related information for each client included in the + * Hono context object. + */ +export type RateLimitInfo = { + limit: number; + used: number; + remaining: number; + resetTime: Date | undefined; +}; diff --git a/packages/rate-limiter/src/types/store.ts b/packages/core/src/types/store.ts similarity index 96% rename from packages/rate-limiter/src/types/store.ts rename to packages/core/src/types/store.ts index 0bb0a0e..0d0b764 100644 --- a/packages/rate-limiter/src/types/store.ts +++ b/packages/core/src/types/store.ts @@ -1,80 +1,80 @@ -import type { ClientRateLimitInfo } from "./clientRateLimitInfo"; -import type { ConfigType } from "./config"; - -export type IncrementResponse = ClientRateLimitInfo; - -/** - * An interface that all hit counter stores must implement. - */ -export type Store = { - /** - * Method that initializes the store, and has access to the options passed to - * the middleware too. - * - * @param options {ConfigType} - The options used to setup the middleware. - */ - init?: (options: ConfigType) => void; - - /** - * Method to fetch a client's hit count and reset time. - * - * @param key {string} - The identifier for a client. - * - * @returns {ClientRateLimitInfo} - The number of hits and reset time for that client. - */ - get?: ( - key: string, - ) => - | Promise - | ClientRateLimitInfo - | undefined; - - /** - * Method to increment a client's hit counter. - * - * @param key {string} - The identifier for a client. - * - * @returns {IncrementResponse | undefined} - The number of hits and reset time for that client. - */ - increment: (key: string) => Promise | IncrementResponse; - - /** - * Method to decrement a client's hit counter. - * - * @param key {string} - The identifier for a client. - */ - decrement: (key: string) => Promise | void; - - /** - * Method to reset a client's hit counter. - * - * @param key {string} - The identifier for a client. - */ - resetKey: (key: string) => Promise | void; - - /** - * Method to reset everyone's hit counter. - */ - resetAll?: () => Promise | void; - - /** - * Method to shutdown the store, stop timers, and release all resources. - */ - shutdown?: () => Promise | void; - - /** - * Flag to indicate that keys incremented in one instance of this store can - * not affect other instances. Typically false if a database is used, true for - * MemoryStore. - * - * Used to help detect double-counting misconfigurations. - */ - localKeys?: boolean; - - /** - * Optional value that the store prepends to keys - * - * Used by the double-count check to avoid false-positives when a key is counted twice, but with different prefixes - */ - prefix?: string; -}; +import type { ClientRateLimitInfo } from "./clientRateLimitInfo"; +import type { ConfigType } from "./config"; + +export type IncrementResponse = ClientRateLimitInfo; + +/** + * An interface that all hit counter stores must implement. + */ +export type Store = { + /** + * Method that initializes the store, and has access to the options passed to + * the middleware too. + * + * @param options {ConfigType} - The options used to setup the middleware. + */ + init?: (options: ConfigType) => void; + + /** + * Method to fetch a client's hit count and reset time. + * + * @param key {string} - The identifier for a client. + * + * @returns {ClientRateLimitInfo} - The number of hits and reset time for that client. + */ + get?: ( + key: string, + ) => + | Promise + | ClientRateLimitInfo + | undefined; + + /** + * Method to increment a client's hit counter. + * + * @param key {string} - The identifier for a client. + * + * @returns {IncrementResponse | undefined} - The number of hits and reset time for that client. + */ + increment: (key: string) => Promise | IncrementResponse; + + /** + * Method to decrement a client's hit counter. + * + * @param key {string} - The identifier for a client. + */ + decrement: (key: string) => Promise | void; + + /** + * Method to reset a client's hit counter. + * + * @param key {string} - The identifier for a client. + */ + resetKey: (key: string) => Promise | void; + + /** + * Method to reset everyone's hit counter. + */ + resetAll?: () => Promise | void; + + /** + * Method to shutdown the store, stop timers, and release all resources. + */ + shutdown?: () => Promise | void; + + /** + * Flag to indicate that keys incremented in one instance of this store can + * not affect other instances. Typically false if a database is used, true for + * MemoryStore. + * + * Used to help detect double-counting misconfigurations. + */ + localKeys?: boolean; + + /** + * Optional value that the store prepends to keys + * + * Used by the double-count check to avoid false-positives when a key is counted twice, but with different prefixes + */ + prefix?: string; +}; diff --git a/packages/rate-limiter/tsconfig.json b/packages/core/tsconfig.json similarity index 95% rename from packages/rate-limiter/tsconfig.json rename to packages/core/tsconfig.json index f5b8565..55364c8 100644 --- a/packages/rate-limiter/tsconfig.json +++ b/packages/core/tsconfig.json @@ -1,22 +1,22 @@ -{ - "extends": "../../tsconfig.base.json", - "compilerOptions": { - "module": "commonjs", - "forceConsistentCasingInFileNames": true, - "strict": true, - "noImplicitOverride": true, - "noPropertyAccessFromIndexSignature": true, - "noImplicitReturns": true, - "noFallthroughCasesInSwitch": true - }, - "files": [], - "include": [], - "references": [ - { - "path": "./tsconfig.lib.json" - }, - { - "path": "./tsconfig.spec.json" - } - ] -} +{ + "extends": "../../tsconfig.base.json", + "compilerOptions": { + "module": "commonjs", + "forceConsistentCasingInFileNames": true, + "strict": true, + "noImplicitOverride": true, + "noPropertyAccessFromIndexSignature": true, + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true + }, + "files": [], + "include": [], + "references": [ + { + "path": "./tsconfig.lib.json" + }, + { + "path": "./tsconfig.spec.json" + } + ] +} diff --git a/packages/rate-limiter/tsconfig.lib.json b/packages/core/tsconfig.lib.json similarity index 100% rename from packages/rate-limiter/tsconfig.lib.json rename to packages/core/tsconfig.lib.json diff --git a/packages/rate-limiter/tsconfig.spec.json b/packages/core/tsconfig.spec.json similarity index 95% rename from packages/rate-limiter/tsconfig.spec.json rename to packages/core/tsconfig.spec.json index 3c002c2..e8efde7 100644 --- a/packages/rate-limiter/tsconfig.spec.json +++ b/packages/core/tsconfig.spec.json @@ -1,26 +1,26 @@ -{ - "extends": "./tsconfig.json", - "compilerOptions": { - "outDir": "../../dist/out-tsc", - "types": [ - "vitest/globals", - "vitest/importMeta", - "vite/client", - "node", - "vitest" - ] - }, - "include": [ - "vite.config.ts", - "vitest.config.ts", - "src/**/*.test.ts", - "src/**/*.spec.ts", - "src/**/*.test.tsx", - "src/**/*.spec.tsx", - "src/**/*.test.js", - "src/**/*.spec.js", - "src/**/*.test.jsx", - "src/**/*.spec.jsx", - "src/**/*.d.ts" - ] -} +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "outDir": "../../dist/out-tsc", + "types": [ + "vitest/globals", + "vitest/importMeta", + "vite/client", + "node", + "vitest" + ] + }, + "include": [ + "vite.config.ts", + "vitest.config.ts", + "src/**/*.test.ts", + "src/**/*.spec.ts", + "src/**/*.test.tsx", + "src/**/*.spec.tsx", + "src/**/*.test.js", + "src/**/*.spec.js", + "src/**/*.test.jsx", + "src/**/*.spec.jsx", + "src/**/*.d.ts" + ] +} diff --git a/packages/rate-limiter/vite.config.ts b/packages/core/vite.config.ts similarity index 96% rename from packages/rate-limiter/vite.config.ts rename to packages/core/vite.config.ts index fd10520..7a619e1 100644 --- a/packages/rate-limiter/vite.config.ts +++ b/packages/core/vite.config.ts @@ -1,27 +1,27 @@ -import { defineConfig } from "vite"; - -import { nxViteTsPaths } from "@nx/vite/plugins/nx-tsconfig-paths.plugin"; - -export default defineConfig({ - root: __dirname, - cacheDir: "../../node_modules/.vite/packages/rate-limiter", - - plugins: [nxViteTsPaths()], - - // Uncomment this if you are using workers. - // worker: { - // plugins: [ nxViteTsPaths() ], - // }, - - test: { - globals: true, - cache: { dir: "../../node_modules/.vitest" }, - environment: "node", - include: ["src/**/*.{test,spec}.{js,mjs,cjs,ts,mts,cts,jsx,tsx}"], - reporters: ["default"], - coverage: { - reportsDirectory: "../../coverage/packages/rate-limiter", - provider: "v8", - }, - }, -}); +import { defineConfig } from "vite"; + +import { nxViteTsPaths } from "@nx/vite/plugins/nx-tsconfig-paths.plugin"; + +export default defineConfig({ + root: __dirname, + cacheDir: "../../node_modules/.vite/packages/rate-limiter", + + plugins: [nxViteTsPaths()], + + // Uncomment this if you are using workers. + // worker: { + // plugins: [ nxViteTsPaths() ], + // }, + + test: { + globals: true, + cache: { dir: "../../node_modules/.vitest" }, + environment: "node", + include: ["src/**/*.{test,spec}.{js,mjs,cjs,ts,mts,cts,jsx,tsx}"], + reporters: ["default"], + coverage: { + reportsDirectory: "../../coverage/packages/rate-limiter", + provider: "v8", + }, + }, +}); diff --git a/tsconfig.base.json b/tsconfig.base.json index bbfc70a..e40ba85 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -15,7 +15,7 @@ "skipDefaultLibCheck": true, "baseUrl": ".", "paths": { - "hono-rate-limiter": ["packages/rate-limiter/src/index.ts"] + "hono-rate-limiter": ["packages/core/src/index.ts"] } }, "exclude": ["node_modules", "tmp"]