-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
424 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// TreeNode class to represent nodes in the decision tree | ||
class TreeNode { | ||
constructor(value) { | ||
this.value = value; | ||
this.children = []; | ||
} | ||
|
||
addChild(childNode) { | ||
this.children.push(childNode); | ||
} | ||
} | ||
|
||
// DecisionTree class to represent the decision tree structure | ||
class DecisionTree { | ||
constructor() { | ||
this.root = null; | ||
} | ||
|
||
insert(value) { | ||
const newNode = new TreeNode(value); | ||
if (!this.root) { | ||
this.root = newNode; | ||
} else { | ||
this.insertNode(this.root, newNode); | ||
} | ||
} | ||
|
||
insertNode(node, newNode) { | ||
// Insert new node into the tree based on some decision logic. | ||
// This logic depends on your specific use case and dataset. | ||
// For simplicity, we'll just compare the values. | ||
if (newNode.value < node.value) { | ||
if (!node.children[0]) { | ||
node.addChild(newNode); | ||
} else { | ||
this.insertNode(node.children[0], newNode); | ||
} | ||
} else { | ||
if (!node.children[1]) { | ||
node.addChild(newNode); | ||
} else { | ||
this.insertNode(node.children[1], newNode); | ||
} | ||
} | ||
} | ||
|
||
search(value) { | ||
return this.searchNode(this.root, value); | ||
} | ||
|
||
searchNode(node, value) { | ||
if (!node) { | ||
return false; | ||
} | ||
|
||
if (node.value === value) { | ||
return true; | ||
} else if (value < node.value) { | ||
return this.searchNode(node.children[0], value); | ||
} else { | ||
return this.searchNode(node.children[1], value); | ||
} | ||
} | ||
} | ||
|
||
module.exports = DecisionTree; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
class KMeansClustering { | ||
constructor(data, k) { | ||
if (!data || data.length === 0) { | ||
throw new Error('Data must not be empty.'); | ||
} | ||
|
||
this.data = data; // An array of data points | ||
this.k = k; // The number of clusters | ||
|
||
this.centroids = this.initializeCentroids(); | ||
this.clusters = this.initializeClusters(); | ||
} | ||
|
||
initializeCentroids() { | ||
// Randomly initialize centroids | ||
const centroids = []; | ||
const dataCopy = [...this.data]; | ||
for (let i = 0; i < this.k; i++) { | ||
const randomIndex = Math.floor(Math.random() * dataCopy.length); | ||
centroids.push(dataCopy.splice(randomIndex, 1)[0]); | ||
} | ||
return centroids; | ||
} | ||
|
||
initializeClusters() { | ||
// Create clusters for each centroid | ||
return new Array(this.k).fill(null).map(() => []); | ||
} | ||
|
||
euclideanDistance(point1, point2) { | ||
// Calculate the Euclidean distance between two data points | ||
if (point1.length !== point2.length) { | ||
throw new Error('Data points must have the same dimension'); | ||
} | ||
let sum = 0; | ||
for (let i = 0; i < point1.length; i++) { | ||
sum += Math.pow(point1[i] - point2[i], 2); | ||
} | ||
return Math.sqrt(sum); | ||
} | ||
|
||
assignToClusters() { | ||
// Assign data points to the nearest cluster | ||
this.clusters = new Array(this.k).fill(null).map(() => []); | ||
|
||
for (const dataPoint of this.data) { | ||
let minDistance = Number.POSITIVE_INFINITY; | ||
let clusterIndex = 0; | ||
|
||
for (let i = 0; i < this.k; i++) { | ||
const distance = this.euclideanDistance(dataPoint, this.centroids[i]); | ||
if (distance < minDistance) { | ||
minDistance = distance; | ||
clusterIndex = i; | ||
} | ||
} | ||
|
||
this.clusters[clusterIndex].push(dataPoint); | ||
} | ||
} | ||
|
||
updateCentroids() { | ||
// Calculate new centroids for each cluster | ||
for (let i = 0; i < this.k; i++) { | ||
if (this.clusters[i].length === 0) continue; | ||
|
||
const newCentroid = this.clusters[i][0].map((_, j) => { | ||
return ( | ||
this.clusters[i].reduce((sum, dataPoint) => sum + dataPoint[j], 0) / | ||
this.clusters[i].length | ||
); | ||
}); | ||
|
||
this.centroids[i] = newCentroid; | ||
} | ||
} | ||
|
||
cluster() { | ||
// Perform k-means clustering | ||
const maxIterations = 100; | ||
let iterations = 0; | ||
|
||
while (iterations < maxIterations) { | ||
this.assignToClusters(); | ||
const oldCentroids = [...this.centroids]; | ||
this.updateCentroids(); | ||
|
||
// Check for convergence | ||
if (this.centroids.every((centroid, i) => this.euclideanDistance(centroid, oldCentroids[i]) === 0)) { | ||
break; | ||
} | ||
|
||
iterations++; | ||
} | ||
|
||
return this.clusters; | ||
} | ||
} | ||
|
||
module.exports = KMeansClustering; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
class LinearRegression { | ||
constructor() { | ||
this.coefficients = { intercept: 0, slope: 0 }; | ||
} | ||
|
||
// Fit the model to the provided data | ||
fit(data) { | ||
if (data.length < 2) { | ||
throw new Error('Insufficient data for linear regression'); | ||
} | ||
|
||
const n = data.length; | ||
const sumX = data.reduce((acc, [x]) => acc + x, 0); | ||
const sumY = data.reduce((acc, [, y]) => acc + y, 0); | ||
const sumXX = data.reduce((acc, [x]) => acc + x * x, 0); | ||
const sumXY = data.reduce((acc, [x, y]) => acc + x * y, 0); | ||
|
||
const slope = | ||
(n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX); | ||
const intercept = | ||
(sumY - slope * sumX) / n; | ||
|
||
this.coefficients.slope = slope; | ||
this.coefficients.intercept = intercept; | ||
} | ||
|
||
// Predict the output for a given input | ||
predict(x) { | ||
return this.coefficients.intercept + this.coefficients.slope * x; | ||
} | ||
} | ||
|
||
module.exports = LinearRegression; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
class LogisticRegression { | ||
constructor() { | ||
this.coefficients = { intercept: 0, slope: [] }; | ||
} | ||
|
||
// Fit the model to the provided data | ||
fit(data) { | ||
if (data.length === 0) { | ||
throw new Error('Data must not be empty.'); | ||
} | ||
|
||
if (data[0].length < 2) { | ||
throw new Error('Insufficient features for logistic regression'); | ||
} | ||
|
||
const n = data.length; | ||
const features = data[0].length - 1; | ||
const X = []; | ||
const y = []; | ||
|
||
for (let i = 0; i < n; i++) { | ||
X.push([1].concat(data[i].slice(0, features))); // Include a 1 for the intercept term | ||
y.push(data[i][features]); | ||
} | ||
|
||
const maxIterations = 1000; | ||
const learningRate = 0.1; | ||
|
||
for (let iteration = 0; iteration < maxIterations; iteration++) { | ||
const predictions = this.predict(X); | ||
const errors = []; | ||
|
||
for (let i = 0; i < n; i++) { | ||
errors.push(predictions[i] - y[i]); | ||
} | ||
|
||
const gradient = []; | ||
for (let j = 0; j < features + 1; j++) { | ||
let sum = 0; | ||
for (let i = 0; i < n; i++) { | ||
sum += errors[i] * X[i][j]; | ||
} | ||
gradient.push(sum / n); | ||
} | ||
|
||
this.coefficients.intercept -= learningRate * gradient[0]; | ||
for (let j = 0; j < features; j++) { | ||
this.coefficients.slope[j] -= learningRate * gradient[j + 1]; | ||
} | ||
} | ||
} | ||
|
||
// Predict the probability of the positive class | ||
predict(features) { | ||
const z = this.coefficients.intercept + this.coefficients.slope.reduce((sum, coefficient, j) => sum + coefficient * features[j], 0); | ||
const probability = 1 / (1 + Math.exp(-z)); | ||
return probability; | ||
} | ||
} | ||
|
||
module.exports = LogisticRegression; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
const DecisionTree = require('../../../algorithms/ml-statistical/decisionTrees'); | ||
|
||
// Test the DecisionTree class | ||
describe('DecisionTree', () => { | ||
test('should insert and search values correctly', () => { | ||
const decisionTree = new DecisionTree(); | ||
|
||
// Insert values into the decision tree | ||
decisionTree.insert(10); | ||
decisionTree.insert(5); | ||
decisionTree.insert(15); | ||
decisionTree.insert(2); | ||
decisionTree.insert(7); | ||
|
||
// Search for values in the decision tree | ||
expect(decisionTree.search(7)).toBe(true); | ||
expect(decisionTree.search(20)).toBe(false); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
const KMeansClustering = require('../../../algorithms/ml-statistical/kMeansClustering'); | ||
|
||
describe('KMeansClustering', () => { | ||
it('should cluster data into 3 clusters', () => { | ||
const data = [ | ||
[1, 2], | ||
[2, 3], | ||
[8, 7], | ||
[9, 8], | ||
[11, 13], | ||
[14, 15], | ||
]; | ||
const k = 3; | ||
|
||
const kMeans = new KMeansClustering(data, k); | ||
const clusters = kMeans.cluster(); | ||
|
||
expect(clusters.length).toBe(k); | ||
}); | ||
|
||
it('should return clusters of data points', () => { | ||
const data = [ | ||
[1, 2], | ||
[2, 3], | ||
[8, 7], | ||
[9, 8], | ||
[11, 13], | ||
[14, 15], | ||
]; | ||
const k = 3; | ||
|
||
const kMeans = new KMeansClustering(data, k); | ||
const clusters = kMeans.cluster(); | ||
|
||
expect(clusters).toBeInstanceOf(Array); | ||
clusters.forEach((cluster) => { | ||
expect(cluster).toBeInstanceOf(Array); | ||
}); | ||
}); | ||
|
||
it('should throw an error with empty data', () => { | ||
const data = []; | ||
const k = 3; | ||
expect(() => new KMeansClustering(data, k)).toThrow('Data must not be empty.'); | ||
}); | ||
|
||
it('should handle clusters with a single data point', () => { | ||
const data = [ | ||
[1, 2], | ||
[2, 3], | ||
[8, 7], | ||
]; | ||
const k = 3; | ||
|
||
const kMeans = new KMeansClustering(data, k); | ||
const clusters = kMeans.cluster(); | ||
|
||
expect(clusters.length).toBe(k); | ||
clusters.forEach((cluster) => { | ||
expect(cluster.length).toBe(1); | ||
}); | ||
}); | ||
|
||
it('should handle clusters with multiple data points', () => { | ||
const data = [ | ||
[1, 2], | ||
[2, 3], | ||
[8, 7], | ||
[9, 8], | ||
[11, 13], | ||
[14, 15], | ||
]; | ||
const k = 3; | ||
|
||
const kMeans = new KMeansClustering(data, k); | ||
const clusters = kMeans.cluster(); | ||
|
||
expect(clusters.length).toBe(k); | ||
clusters.forEach((cluster) => { | ||
expect(cluster.length).toBeGreaterThan(1); | ||
}); | ||
}); | ||
}); |
Oops, something went wrong.