Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,10 @@ import { shuffle } from 'shuffle-seed';

import { CsvReadOptions, CsvTable } from './loadCsv.models';
import filterColumns from './filterColumns';
import splitTestData from './splitTestData';

// Fixed seed string. NOTE(review): presumably the fallback seed handed to
// `shuffle` when the caller supplies none — confirm against loadCsv's body.
const defaultShuffleSeed = 'mncv9340ur';

/**
 * Splits feature and label rows into a leading training portion and a
 * trailing test portion, cutting both tables at the same row index.
 *
 * @param features - Feature rows.
 * @param labels - Label rows; sliced at the same index as `features`.
 * @param splitTest - When a number, the count of rows to move into the test
 *   set, clamped to `[0, features.length]`. When a boolean, the first
 *   `floor(features.length / 2)` rows stay as training data and the remainder
 *   becomes test data.
 * @returns The training `features`/`labels` plus the held-out
 *   `testFeatures`/`testLabels`.
 */
const splitTestData = (
  features: CsvTable,
  labels: CsvTable,
  splitTest: boolean | number
) => {
  const rowCount = features.length;

  // Index of the first test row. A numeric `splitTest` is a test-row count,
  // so it is subtracted from the total; it is clamped to the full row count
  // (not rowCount - 1) so that every row may be placed in the test set.
  const testStartIndex =
    typeof splitTest === 'number'
      ? rowCount - Math.max(0, Math.min(splitTest, rowCount))
      : Math.floor(rowCount / 2);

  return {
    testFeatures: features.slice(testStartIndex),
    testLabels: labels.slice(testStartIndex),
    features: features.slice(0, testStartIndex),
    labels: labels.slice(0, testStartIndex),
  };
};

const loadCsv = (filename: string, options: CsvReadOptions) => {
const {
featureColumns,
Expand Down
23 changes: 23 additions & 0 deletions src/splitTestData.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { CsvTable } from './loadCsv.models';

/**
 * Partitions feature and label rows into training data (leading rows) and
 * test data (trailing rows), cutting both tables at the same index.
 *
 * @param features - Feature rows.
 * @param labels - Label rows; sliced at the same index as `features`.
 * @param splitTest - A number is the count of trailing rows to hold out for
 *   testing, clamped to `[0, features.length]`; `true` holds out
 *   `floor(features.length / 2)` rows.
 * @returns The training `features`/`labels` and the held-out
 *   `testFeatures`/`testLabels`.
 */
const splitTestData = (
  features: CsvTable,
  labels: CsvTable,
  splitTest: true | number
) => {
  const rowCount = features.length;

  let heldOutCount: number;
  if (typeof splitTest === 'number') {
    // Keep the requested size within the rows that actually exist.
    heldOutCount = Math.max(0, Math.min(splitTest, rowCount));
  } else {
    heldOutCount = Math.floor(features.length / 2);
  }

  // Everything from this index onward belongs to the test set.
  const cut = rowCount - heldOutCount;

  const trainFeatures = features.slice(0, cut);
  const trainLabels = labels.slice(0, cut);
  const testFeatures = features.slice(cut);
  const testLabels = labels.slice(cut);

  return {
    features: trainFeatures,
    labels: trainLabels,
    testFeatures,
    testLabels,
  };
};

export default splitTestData;
81 changes: 81 additions & 0 deletions tests/splitTestData.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import splitTestData from '../src/splitTestData';

// Shared fixture: four two-value feature rows, paired index-for-index with
// four single-value label rows.
const tables = {
  features: [[1, 2], [3, 4], [5, 6], [7, 8]],
  labels: [9, 10, 11, 12].map((label) => [label]),
};

// Passing `true` should divide the four fixture rows evenly: the first two
// stay as training data and the last two become test data.
test('Default splitting, splits in half', () => {
  const split = splitTestData(tables.features, tables.labels, true);

  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
  ]);
  expect(split.labels).toMatchObject([[9], [10]]);
  expect(split.testFeatures).toMatchObject([
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[11], [12]]);
});

// A numeric split of 1 should hold out exactly the last row for testing and
// leave the first three rows as training data.
test('Splitting a fixed amount works', () => {
  const split = splitTestData(tables.features, tables.labels, 1);

  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
  ]);
  expect(split.labels).toMatchObject([[9], [10], [11]]);
  expect(split.testFeatures).toMatchObject([[7, 8]]);
  expect(split.testLabels).toMatchObject([[12]]);
});

// Requesting more test rows than exist should clamp to the row count, so the
// training side ends up empty and every row lands in the test side.
test('Splitting more than row length splits all rows into test data', () => {
  const oversizedSplit = tables.features.length * 2;
  const split = splitTestData(tables.features, tables.labels, oversizedSplit);

  expect(split.features).toMatchObject([]);
  expect(split.labels).toMatchObject([]);
  expect(split.testFeatures).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[9], [10], [11], [12]]);
});

// Zero and negative split sizes should both leave the test side empty and
// keep every row as training data.
test('Splitting less than or equal to 0 places all rows into normal data', () => {
  for (const splitLength of [0, -1]) {
    const split = splitTestData(tables.features, tables.labels, splitLength);

    expect(split.features).toMatchObject([
      [1, 2],
      [3, 4],
      [5, 6],
      [7, 8],
    ]);
    expect(split.labels).toMatchObject([[9], [10], [11], [12]]);
    expect(split.testFeatures).toMatchObject([]);
    expect(split.testLabels).toMatchObject([]);
  }
});