diff --git a/src/function/matrix/kron.js b/src/function/matrix/kron.js index a7a92b1adc..4390334a70 100644 --- a/src/function/matrix/kron.js +++ b/src/function/matrix/kron.js @@ -72,6 +72,11 @@ export const createKron = /* #__PURE__ */ factory(name, dependencies, ({ typed, const t = [] let r = [] + // Check if both a and b are of length 1 + if (a.length === 1 && b.length === 1) { + return a[0].flatMap((y) => b[0].map((x) => multiplyScalar(y, x))) + } + return a.map(function (a) { return b.map(function (b) { r = [] diff --git a/test/unit-tests/function/matrix/kron.test.js b/test/unit-tests/function/matrix/kron.test.js index 56af01123d..74470ce8b8 100644 --- a/test/unit-tests/function/matrix/kron.test.js +++ b/test/unit-tests/function/matrix/kron.test.js @@ -29,13 +29,12 @@ describe('kron', function () { }) it('should calculate product for empty 2D Arrays', function () { - assert.deepStrictEqual(math.kron([[]], [[]]), [[]]) + assert.deepStrictEqual(math.kron([[]], [[]]), []) }) - it('should calculate product for 1D Arrays', function () { assert.deepStrictEqual(math.kron([1, 1], [[1, 0], [0, 1]]), [[1, 0, 1, 0], [0, 1, 0, 1]]) assert.deepStrictEqual(math.kron([[1, 0], [0, 1]], [1, 1]), [[1, 1, 0, 0], [0, 0, 1, 1]]) - assert.deepStrictEqual(math.kron([1, 2, 6, 8], [12, 1, 2, 3]), [[12, 1, 2, 3, 24, 2, 4, 6, 72, 6, 12, 18, 96, 8, 16, 24]]) + assert.deepStrictEqual(math.kron([1, 2, 6, 8], [12, 1, 2, 3]), [12, 1, 2, 3, 24, 2, 4, 6, 72, 6, 12, 18, 96, 8, 16, 24]) }) it('should support complex numbers', function () {