diff --git a/src/filterColumns.ts b/src/filterColumns.ts index aec6b0a..f2ba59c 100644 --- a/src/filterColumns.ts +++ b/src/filterColumns.ts @@ -1,12 +1,30 @@ import { CsvTable } from './loadCsv.models'; -const filterColumns = (table: CsvTable, columnNames: string[]) => { - const indexKeepDecisions = table[0].map( - (header) => columnNames.indexOf(header as string) > -1 +/** + * Filters and re-orders columns in a given CSV table. + * + * Where n is the number of non-header cells in `table`, `m` is the number of header cells in `table`, and h is the number of items in `headers` + * + * Time complexity: O(n + mh) + * + * Space complexity: O(n + m + h) + */ +const filterColumns = (table: CsvTable, headers: string[]) => { + const indexKeepDecisions = table[0].map((columnName) => + headers.includes(columnName as string) ); - return table.map((row) => - row.filter((_, index) => indexKeepDecisions[index]) + const filteredColumnNames = table[0].filter((_, i) => indexKeepDecisions[i]); + const indexMap = filteredColumnNames.map((columnName) => + headers.indexOf(columnName as string) ); + return table.map((row) => { + const newRow = new Array(indexMap.length); + for (let i = 0, j = 0; i < row.length; i++) { + if (!indexKeepDecisions[i]) continue; + newRow[indexMap[j++]] = row[i]; + } + return newRow; + }); }; export default filterColumns; diff --git a/src/loadCsv.ts b/src/loadCsv.ts index aa7a0fe..0db7cd8 100644 --- a/src/loadCsv.ts +++ b/src/loadCsv.ts @@ -45,8 +45,6 @@ const loadCsv = ( throw new Error('CSV file can not be shorter than two rows'); } - applyMappings(data, mappings, new Set(flatten)); - const tables: { [key: string]: CsvTable } = { labels: filterColumns(data, labelColumns), features: filterColumns(data, featureColumns), @@ -54,6 +52,11 @@ const loadCsv = ( testLabels: [], }; + const flattenSet = new Set(flatten); + + applyMappings(tables.labels, mappings, flattenSet); + applyMappings(tables.features, mappings, flattenSet); + tables.labels.shift(); const featureColumnNames = tables.features.shift() as string[]; diff --git a/tests/filterColumns.test.ts b/tests/filterColumns.test.ts index a92de95..c3999e8 100644 --- a/tests/filterColumns.test.ts +++ b/tests/filterColumns.test.ts @@ -2,8 +2,8 @@ import filterColumns from '../src/filterColumns'; const data = [ ['lat', 'lng', 'country'], - ['0.234', '1.47', 'SomeCountria'], - ['-293.2', '103.34', 'SomeOtherCountria'], + ['0', '1.47', 'SomeCountria'], + ['1', '103.34', 'SomeOtherCountria'], ]; test('Filtering a single column works correctly', () => { @@ -11,11 +11,11 @@ test('Filtering a single column works correctly', () => { expect(result).toMatchObject([['lng'], ['1.47'], ['103.34']]); }); -test('Filtering multiple columns works correctly', () => { - const result = filterColumns(data, ['country', 'lng']); // Column order from the CSV should be preserved. +test('Filtering multiple columns works correctly, respects order in second argument, does not break with multiple same name columns', () => { + const result = filterColumns(data, ['country', 'lat']); // Column order from the CSV should be preserved. expect(result).toMatchObject([ - ['lng', 'country'], - ['1.47', 'SomeCountria'], - ['103.34', 'SomeOtherCountria'], + ['country', 'lat'], + ['SomeCountria', '0'], + ['SomeOtherCountria', '1'], ]); });