Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(middleware-flexible-checksums): perform checksum calculation and validation by default #6750

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions packages/middleware-flexible-checksums/src/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import {
Encoder,
GetAwsChunkedEncodingStream,
HashConstructor,
Provider,
StreamCollector,
StreamHasher,
} from "@smithy/types";

import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface PreviouslyResolved {
/**
* The function that will be used to convert binary data to a base64-encoded string.
Expand All @@ -31,6 +34,16 @@ export interface PreviouslyResolved {
*/
md5: ChecksumConstructor | HashConstructor;

/**
* Determines when a checksum will be calculated for request payloads
*/
requestChecksumCalculation: Provider<RequestChecksumCalculation>;

/**
* Determines when a checksum will be calculated for response payloads
*/
responseChecksumValidation: Provider<ResponseChecksumValidation>;

/**
* A constructor for a class implementing the {@link Hash} interface that computes SHA1 hashes.
* @internal
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import { setFeature } from "@aws-sdk/core";
import { afterEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";
import { flexibleChecksumsInputMiddleware } from "./flexibleChecksumsInputMiddleware";

vi.mock("@aws-sdk/core");

describe(flexibleChecksumsInputMiddleware.name, () => {
const mockNext = vi.fn();
const mockRequestValidationModeMember = "mockRequestValidationModeMember";

const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED),
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_SUPPORTED),
} as PreviouslyResolved;

afterEach(() => {
expect(mockNext).toHaveBeenCalledTimes(1);
vi.clearAllMocks();
});

describe("sets input.requestValidationModeMember", () => {
it("when requestValidationModeMember is defined and responseChecksumValidation is supported", async () => {
const mockMiddlewareConfigWithMockRequestValidationModeMember = {
requestValidationModeMember: mockRequestValidationModeMember,
};
const handler = flexibleChecksumsInputMiddleware(
mockConfig,
mockMiddlewareConfigWithMockRequestValidationModeMember
)(mockNext, {});
await handler({ input: {} });
expect(mockNext).toHaveBeenCalledWith({ input: { [mockRequestValidationModeMember]: "ENABLED" } });
});
});

describe("leaves input.requestValidationModeMember", () => {
const mockArgs = { input: {} };

it("when requestValidationModeMember is not defined", async () => {
const handler = flexibleChecksumsInputMiddleware(mockConfig, {})(mockNext, {});
await handler(mockArgs);
expect(mockNext).toHaveBeenCalledWith(mockArgs);
});

it("when responseChecksumValidation is required", async () => {
const mockConfigResWhenRequired = {
...mockConfig,
responseChecksumValidation: () => Promise.resolve(ResponseChecksumValidation.WHEN_REQUIRED),
} as PreviouslyResolved;

const handler = flexibleChecksumsInputMiddleware(mockConfigResWhenRequired, {})(mockNext, {});
await handler(mockArgs);

expect(mockNext).toHaveBeenCalledWith(mockArgs);
});
});

describe("set feature", () => {
it.each([
[
"FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED",
"a",
"requestChecksumCalculation",
RequestChecksumCalculation.WHEN_REQUIRED,
],
[
"FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED",
"Z",
"requestChecksumCalculation",
RequestChecksumCalculation.WHEN_SUPPORTED,
],
[
"FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED",
"c",
"responseChecksumValidation",
ResponseChecksumValidation.WHEN_REQUIRED,
],
[
"FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED",
"b",
"responseChecksumValidation",
ResponseChecksumValidation.WHEN_SUPPORTED,
],
])("logs %s:%s when %s=%s", async (feature, value, configKey, configValue) => {
const mockConfigOverride = {
...mockConfig,
[configKey]: () => Promise.resolve(configValue),
} as PreviouslyResolved;

const handler = flexibleChecksumsInputMiddleware(mockConfigOverride, {})(mockNext, {});
await handler({ input: {} });

expect(setFeature).toHaveBeenCalledTimes(2);
if (configKey === "requestChecksumCalculation") {
expect(setFeature).toHaveBeenNthCalledWith(1, expect.anything(), feature, value);
} else {
expect(setFeature).toHaveBeenNthCalledWith(2, expect.anything(), feature, value);
}
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { setFeature } from "@aws-sdk/core";
import {
HandlerExecutionContext,
MetadataBearer,
RelativeMiddlewareOptions,
SerializeHandler,
SerializeHandlerArguments,
SerializeHandlerOutput,
SerializeMiddleware,
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { RequestChecksumCalculation, ResponseChecksumValidation } from "./constants";

export interface FlexibleChecksumsInputMiddlewareConfig {
/**
* Defines a top-level operation input member used to opt-in to best-effort validation
* of a checksum returned in the HTTP response of the operation.
*/
requestValidationModeMember?: string;
}

/**
* @internal
*/
export const flexibleChecksumsInputMiddlewareOptions: RelativeMiddlewareOptions = {
name: "flexibleChecksumsInputMiddleware",
toMiddleware: "serializerMiddleware",
relation: "before",
tags: ["BODY_CHECKSUM"],
override: true,
};

/**
* @internal
*
* The input counterpart to the flexibleChecksumsMiddleware.
*/
export const flexibleChecksumsInputMiddleware =
(
config: PreviouslyResolved,
middlewareConfig: FlexibleChecksumsInputMiddlewareConfig
): SerializeMiddleware<any, any> =>
<Output extends MetadataBearer>(
next: SerializeHandler<any, Output>,
context: HandlerExecutionContext
): SerializeHandler<any, Output> =>
async (args: SerializeHandlerArguments<any>): Promise<SerializeHandlerOutput<Output>> => {
const input = args.input;
const { requestValidationModeMember } = middlewareConfig;

const requestChecksumCalculation = await config.requestChecksumCalculation();
const responseChecksumValidation = await config.responseChecksumValidation();

switch (requestChecksumCalculation) {
case RequestChecksumCalculation.WHEN_REQUIRED:
setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_REQUIRED", "a");
break;
case RequestChecksumCalculation.WHEN_SUPPORTED:
setFeature(context, "FLEXIBLE_CHECKSUMS_REQ_WHEN_SUPPORTED", "Z");
break;
}

switch (responseChecksumValidation) {
case ResponseChecksumValidation.WHEN_REQUIRED:
setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_REQUIRED", "c");
break;
case ResponseChecksumValidation.WHEN_SUPPORTED:
setFeature(context, "FLEXIBLE_CHECKSUMS_RES_WHEN_SUPPORTED", "b");
break;
}

// The value for input member to opt-in to best-effort validation of a checksum returned in the HTTP response is not set.
if (requestValidationModeMember && !input[requestValidationModeMember]) {
// Set requestValidationModeMember as ENABLED only if response checksum validation is supported.
if (responseChecksumValidation === ResponseChecksumValidation.WHEN_SUPPORTED) {
input[requestValidationModeMember] = "ENABLED";
}
}

return next(args);
};
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { BuildHandlerArguments } from "@smithy/types";
import { afterEach, beforeEach, describe, expect, test as it, vi } from "vitest";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants";
import { flexibleChecksumsMiddleware } from "./flexibleChecksumsMiddleware";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
Expand All @@ -13,6 +13,7 @@ import { isStreaming } from "./isStreaming";
import { selectChecksumAlgorithmFunction } from "./selectChecksumAlgorithmFunction";
import { stringHasher } from "./stringHasher";

vi.mock("@aws-sdk/core");
vi.mock("@smithy/protocol-http");
vi.mock("./getChecksumAlgorithmForRequest");
vi.mock("./getChecksumLocationName");
Expand All @@ -28,10 +29,14 @@ describe(flexibleChecksumsMiddleware.name, () => {
const mockChecksum = "mockChecksum";
const mockChecksumAlgorithmFunction = vi.fn();
const mockChecksumLocationName = "mock-checksum-location-name";
const mockRequestAlgorithmMember = "mockRequestAlgorithmMember";
const mockRequestAlgorithmMemberHttpHeader = "mock-request-algorithm-member-http-header";

const mockInput = {};
const mockConfig = {} as PreviouslyResolved;
const mockMiddlewareConfig = { requestChecksumRequired: false };
const mockConfig = {
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_REQUIRED),
} as PreviouslyResolved;
const mockMiddlewareConfig = { input: mockInput, requestChecksumRequired: false };

const mockBody = { body: "mockRequestBody" };
const mockHeaders = { "content-length": 100, "content-encoding": "gzip" };
Expand All @@ -41,9 +46,8 @@ describe(flexibleChecksumsMiddleware.name, () => {

beforeEach(() => {
mockNext.mockResolvedValueOnce(mockResult);
const { isInstance } = HttpRequest;
(isInstance as unknown as any).mockReturnValue(true);
vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.MD5);
vi.mocked(HttpRequest.isInstance).mockReturnValue(true);
vi.mocked(getChecksumAlgorithmForRequest).mockReturnValue(ChecksumAlgorithm.CRC32);
vi.mocked(getChecksumLocationName).mockReturnValue(mockChecksumLocationName);
vi.mocked(hasHeader).mockReturnValue(true);
vi.mocked(hasHeaderWithPrefix).mockReturnValue(false);
Expand All @@ -58,8 +62,7 @@ describe(flexibleChecksumsMiddleware.name, () => {

describe("skips", () => {
it("if not an instance of HttpRequest", async () => {
const { isInstance } = HttpRequest;
(isInstance as unknown as any).mockReturnValue(false);
vi.mocked(HttpRequest.isInstance).mockReturnValue(false);
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
await handler(mockArgs);
expect(getChecksumAlgorithmForRequest).not.toHaveBeenCalled();
Expand All @@ -77,7 +80,7 @@ describe(flexibleChecksumsMiddleware.name, () => {
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
});

it("if header is already present", async () => {
it("skip if header is already present", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, mockMiddlewareConfig)(mockNext, {});
vi.mocked(hasHeaderWithPrefix).mockReturnValue(true);

Expand All @@ -94,11 +97,53 @@ describe(flexibleChecksumsMiddleware.name, () => {

describe("adds checksum in the request header", () => {
afterEach(() => {
expect(HttpRequest.isInstance).toHaveBeenCalledTimes(1);
expect(hasHeaderWithPrefix).toHaveBeenCalledTimes(1);
expect(getChecksumAlgorithmForRequest).toHaveBeenCalledTimes(1);
expect(getChecksumLocationName).toHaveBeenCalledTimes(1);
expect(selectChecksumAlgorithmFunction).toHaveBeenCalledTimes(1);
});

describe("if input.requestAlgorithmMember can be set", () => {
describe("input[requestAlgorithmMember] is not defined and", () => {
const mockMwConfigWithReqAlgoMember = {
...mockMiddlewareConfig,
requestAlgorithmMember: {
name: mockRequestAlgorithmMember,
httpHeader: mockRequestAlgorithmMemberHttpHeader,
},
};

it("requestChecksumCalculation is supported", async () => {
const handler = flexibleChecksumsMiddleware(
{
...mockConfig,
requestChecksumCalculation: () => Promise.resolve(RequestChecksumCalculation.WHEN_SUPPORTED),
},
mockMwConfigWithReqAlgoMember
)(mockNext, {});
await handler(mockArgs);
expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM);
expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual(
DEFAULT_CHECKSUM_ALGORITHM
);
});

it("requestChecksumRequired is set to true", async () => {
const handler = flexibleChecksumsMiddleware(mockConfig, {
...mockMwConfigWithReqAlgoMember,
requestChecksumRequired: true,
})(mockNext, {});

await handler(mockArgs);
expect(mockNext.mock.calls[0][0].input[mockRequestAlgorithmMember]).toEqual(DEFAULT_CHECKSUM_ALGORITHM);
expect(mockNext.mock.calls[0][0].request.headers[mockRequestAlgorithmMemberHttpHeader]).toEqual(
DEFAULT_CHECKSUM_ALGORITHM
);
});
});
});

it("for streaming body", async () => {
vi.mocked(isStreaming).mockReturnValue(true);
const mockUpdatedBody = { body: "mockUpdatedBody" };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
} from "@smithy/types";

import { PreviouslyResolved } from "./configuration";
import { ChecksumAlgorithm } from "./constants";
import { ChecksumAlgorithm, DEFAULT_CHECKSUM_ALGORITHM, RequestChecksumCalculation } from "./constants";
import { getChecksumAlgorithmForRequest } from "./getChecksumAlgorithmForRequest";
import { getChecksumLocationName } from "./getChecksumLocationName";
import { hasHeader } from "./hasHeader";
Expand Down Expand Up @@ -73,10 +73,26 @@ export const flexibleChecksumsMiddleware =
const { body: requestBody, headers } = request;
const { base64Encoder, streamHasher } = config;
const { requestChecksumRequired, requestAlgorithmMember } = middlewareConfig;
const requestChecksumCalculation = await config.requestChecksumCalculation();

const requestAlgorithmMemberName = requestAlgorithmMember?.name;
const requestAlgorithmMemberHttpHeader = requestAlgorithmMember?.httpHeader;
// The value for input member to configure flexible checksum is not set.
if (requestAlgorithmMemberName && !input[requestAlgorithmMemberName]) {
// Set requestAlgorithmMember as default checksum algorithm only if request checksum calculation is supported
// or request checksum is required.
if (requestChecksumCalculation === RequestChecksumCalculation.WHEN_SUPPORTED || requestChecksumRequired) {
input[requestAlgorithmMemberName] = DEFAULT_CHECKSUM_ALGORITHM;
if (requestAlgorithmMemberHttpHeader) {
headers[requestAlgorithmMemberHttpHeader] = DEFAULT_CHECKSUM_ALGORITHM;
}
}
}

const checksumAlgorithm = getChecksumAlgorithmForRequest(input, {
requestChecksumRequired,
requestAlgorithmMember: requestAlgorithmMember?.name,
requestChecksumCalculation,
});
let updatedBody = requestBody;
let updatedHeaders = headers;
Expand Down
Loading
Loading