Skip to content

Commit 0a86384

Browse files
feat: Add bigframes.pandas.crosstab
1 parent ef5e83a commit 0a86384

File tree

7 files changed

+261
-3
lines changed

7 files changed

+261
-3
lines changed

bigframes/core/reshape/api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from bigframes.core.reshape.concat import concat
1616
from bigframes.core.reshape.encoding import get_dummies
1717
from bigframes.core.reshape.merge import merge
18+
from bigframes.core.reshape.pivot import crosstab
1819
from bigframes.core.reshape.tile import cut, qcut
1920

20-
__all__ = ["concat", "get_dummies", "merge", "cut", "qcut"]
21+
__all__ = ["concat", "get_dummies", "merge", "cut", "qcut", "crosstab"]

bigframes/core/reshape/pivot.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from typing import Optional, TYPE_CHECKING
17+
18+
import bigframes_vendored.pandas.core.reshape.pivot as vendored_pandas_pivot
19+
import pandas as pd
20+
21+
import bigframes
22+
from bigframes.core import convert, utils
23+
from bigframes.core.reshape import concat
24+
from bigframes.dataframe import DataFrame
25+
26+
if TYPE_CHECKING:
27+
import bigframes.session
28+
29+
30+
def crosstab(
31+
index,
32+
columns,
33+
values=None,
34+
rownames=None,
35+
colnames=None,
36+
aggfunc=None,
37+
*,
38+
session: Optional[bigframes.session.Session] = None,
39+
) -> DataFrame:
40+
if _is_list_of_lists(index):
41+
index = [
42+
convert.to_bf_series(subindex, default_index=None, session=session)
43+
for subindex in index
44+
]
45+
else:
46+
index = [convert.to_bf_series(index, default_index=None, session=session)]
47+
if _is_list_of_lists(columns):
48+
columns = [
49+
convert.to_bf_series(subcol, default_index=None, session=session)
50+
for subcol in columns
51+
]
52+
else:
53+
columns = [convert.to_bf_series(columns, default_index=None, session=session)]
54+
55+
df = concat.concat([*index, *columns], join="inner", axis=1)
56+
# for uniqueness
57+
tmp_index_names = [f"_crosstab_index_{i}" for i in range(len(index))]
58+
tmp_col_names = [f"_crosstab_columns_{i}" for i in range(len(columns))]
59+
df.columns = pd.Index([*tmp_index_names, *tmp_col_names])
60+
61+
values = (
62+
convert.to_bf_series(values, default_index=df.index, session=session)
63+
if values is not None
64+
else 0
65+
)
66+
67+
df["_crosstab_values"] = values
68+
pivot_table = df.pivot_table(
69+
values="_crosstab_values",
70+
index=tmp_index_names,
71+
columns=tmp_col_names,
72+
aggfunc=aggfunc or "count",
73+
sort=False,
74+
)
75+
pivot_table.index.names = rownames or [i.name for i in index]
76+
pivot_table.columns.names = colnames or [c.name for c in columns]
77+
if aggfunc is None:
78+
# TODO: Push this into pivot_table itself
79+
pivot_table = pivot_table.fillna(0)
80+
return pivot_table
81+
82+
83+
def _is_list_of_lists(item) -> bool:
84+
if not utils.is_list_like(item):
85+
return False
86+
return all(convert.can_convert_to_series(subitem) for subitem in item)
87+
88+
89+
crosstab.__doc__ = vendored_pandas_pivot.crosstab.__doc__

bigframes/dataframe.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3479,7 +3479,34 @@ def pivot_table(
34793479
] = None,
34803480
columns: typing.Union[blocks.Label, Sequence[blocks.Label]] = None,
34813481
aggfunc: str = "mean",
3482+
fill_value=None,
3483+
margins: bool = False,
3484+
dropna: bool = True,
3485+
margins_name: Hashable = "All",
3486+
observed: bool = False,
3487+
sort: bool = True,
34823488
) -> DataFrame:
3489+
if fill_value is not None:
3490+
raise NotImplementedError(
3491+
"DataFrame.pivot_table fill_value arg not supported. {constants.FEEDBACK_LINK}"
3492+
)
3493+
if margins:
3494+
raise NotImplementedError(
3495+
"DataFrame.pivot_table margins arg not supported. {constants.FEEDBACK_LINK}"
3496+
)
3497+
if not dropna:
3498+
raise NotImplementedError(
3499+
"DataFrame.pivot_table dropna arg not supported. {constants.FEEDBACK_LINK}"
3500+
)
3501+
if margins_name != "All":
3502+
raise NotImplementedError(
3503+
"DataFrame.pivot_table margins_name arg not supported. {constants.FEEDBACK_LINK}"
3504+
)
3505+
if observed:
3506+
raise NotImplementedError(
3507+
"DataFrame.pivot_table observed arg not supported. {constants.FEEDBACK_LINK}"
3508+
)
3509+
34833510
if isinstance(index, Iterable) and not (
34843511
isinstance(index, blocks.Label) and index in self.columns
34853512
):
@@ -3521,7 +3548,9 @@ def pivot_table(
35213548
columns=columns,
35223549
index=index,
35233550
values=values if len(values) > 1 else None,
3524-
).sort_index()
3551+
)
3552+
if sort:
3553+
pivoted = pivoted.sort_index()
35253554

35263555
# TODO: Remove the reordering step once the issue is resolved.
35273556
# The pivot_table method results in multi-index columns that are always ordered.

bigframes/pandas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import bigframes.core.blocks
3232
import bigframes.core.global_session as global_session
3333
import bigframes.core.indexes
34-
from bigframes.core.reshape.api import concat, cut, get_dummies, merge, qcut
34+
from bigframes.core.reshape.api import concat, crosstab, cut, get_dummies, merge, qcut
3535
import bigframes.core.tools
3636
import bigframes.dataframe
3737
import bigframes.enums
@@ -372,6 +372,7 @@ def reset_session():
372372
_functions = [
373373
clean_up_by_session_id,
374374
concat,
375+
crosstab,
375376
cut,
376377
deploy_remote_function,
377378
deploy_udf,

bigframes/session/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,21 @@ def cut(self, *args, **kwargs) -> bigframes.series.Series:
23122312
**kwargs,
23132313
)
23142314

2315+
def crosstab(self, *args, **kwargs) -> dataframe.DataFrame:
2316+
"""Compute a simple cross tabulation of two (or more) factors.
2317+
2318+
Included for compatibility between bpd and Session.
2319+
2320+
See :func:`bigframes.pandas.crosstab` for full documentation.
2321+
"""
2322+
import bigframes.core.reshape.pivot
2323+
2324+
return bigframes.core.reshape.pivot.crosstab(
2325+
*args,
2326+
session=self,
2327+
**kwargs,
2328+
)
2329+
23152330
def DataFrame(self, *args, **kwargs):
23162331
"""Constructs a DataFrame.
23172332

tests/system/small/test_pandas.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,72 @@ def test_merge_raises_error_when_left_right_on_set(scalars_dfs):
454454
)
455455

456456

457+
def test_crosstab_aligned_series(scalars_dfs):
458+
scalars_df, scalars_pandas_df = scalars_dfs
459+
460+
pd_result = pd.crosstab(
461+
scalars_pandas_df["int64_col"], scalars_pandas_df["int64_too"]
462+
)
463+
bf_result = bpd.crosstab(
464+
scalars_df["int64_col"], scalars_df["int64_too"]
465+
).to_pandas()
466+
467+
assert_pandas_df_equal(bf_result, pd_result, check_dtype=False)
468+
469+
470+
def test_crosstab_nondefault_func(scalars_dfs):
471+
scalars_df, scalars_pandas_df = scalars_dfs
472+
473+
pd_result = pd.crosstab(
474+
scalars_pandas_df["int64_col"],
475+
scalars_pandas_df["int64_too"],
476+
values=scalars_pandas_df["float64_col"],
477+
aggfunc="mean",
478+
)
479+
bf_result = bpd.crosstab(
480+
scalars_df["int64_col"],
481+
scalars_df["int64_too"],
482+
values=scalars_df["float64_col"],
483+
aggfunc="mean",
484+
).to_pandas()
485+
486+
assert_pandas_df_equal(bf_result, pd_result, check_dtype=False)
487+
488+
489+
def test_crosstab_multi_cols(scalars_dfs):
490+
scalars_df, scalars_pandas_df = scalars_dfs
491+
492+
pd_result = pd.crosstab(
493+
[scalars_pandas_df["int64_col"], scalars_pandas_df["bool_col"]],
494+
[scalars_pandas_df["int64_too"], scalars_pandas_df["string_col"]],
495+
rownames=["a", "b"],
496+
colnames=["c", "d"],
497+
)
498+
bf_result = bpd.crosstab(
499+
[scalars_df["int64_col"], scalars_df["bool_col"]],
500+
[scalars_df["int64_too"], scalars_df["string_col"]],
501+
rownames=["a", "b"],
502+
colnames=["c", "d"],
503+
).to_pandas()
504+
505+
assert_pandas_df_equal(bf_result, pd_result, check_dtype=False)
506+
507+
508+
def test_crosstab_unaligned_series(scalars_dfs, session):
509+
scalars_df, scalars_pandas_df = scalars_dfs
510+
other_pd_series = pd.Series(
511+
[10, 20, 10, 30, 10], index=[5, 4, 1, 2, 3], dtype="Int64", name="nums"
512+
)
513+
other_bf_series = session.Series(
514+
[10, 20, 10, 30, 10], index=[5, 4, 1, 2, 3], name="nums"
515+
)
516+
517+
pd_result = pd.crosstab(scalars_pandas_df["int64_col"], other_pd_series)
518+
bf_result = bpd.crosstab(scalars_df["int64_col"], other_bf_series).to_pandas()
519+
520+
assert_pandas_df_equal(bf_result, pd_result, check_dtype=False)
521+
522+
457523
def _convert_pandas_category(pd_s: pd.Series):
458524
"""
459525
Transforms a pandas Series with Categorical dtype into a bigframes-compatible
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Contains code from https://github.com/pandas-dev/pandas/blob/main/pandas/core/reshape/pivot.py
2+
from __future__ import annotations
3+
4+
from bigframes import constants
5+
6+
7+
def crosstab(
8+
index,
9+
columns,
10+
values=None,
11+
rownames=None,
12+
colnames=None,
13+
aggfunc=None,
14+
):
15+
"""
16+
Compute a simple cross tabulation of two (or more) factors.
17+
18+
By default, computes a frequency table of the factors unless an
19+
array of values and an aggregation function are passed.
20+
21+
**Examples:**
22+
>>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
23+
... "bar", "bar", "foo", "foo", "foo"], dtype=object)
24+
>>> b = np.array(["one", "one", "one", "two", "one", "one",
25+
... "one", "two", "two", "two", "one"], dtype=object)
26+
>>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
27+
... "shiny", "dull", "shiny", "shiny", "shiny"],
28+
... dtype=object)
29+
>>> bpd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
30+
b one two
31+
c dull shiny dull shiny
32+
a
33+
bar 1 2 1 0
34+
foo 2 2 1 2
35+
<BLANKLINE>
36+
[2 rows x 4 columns]
37+
38+
Args:
39+
index (array-like, Series, or list of arrays/Series):
40+
Values to group by in the rows.
41+
columns (array-like, Series, or list of arrays/Series):
42+
Values to group by in the columns.
43+
values (array-like, optional):
44+
Array of values to aggregate according to the factors.
45+
Requires `aggfunc` be specified.
46+
rownames (sequence, default None):
47+
If passed, must match number of row arrays passed.
48+
colnames (sequence, default None):
49+
If passed, must match number of column arrays passed.
50+
aggfunc (function, optional):
51+
If specified, requires `values` be specified as well.
52+
53+
Returns:
54+
DataFrame:
55+
Cross tabulation of the data.
56+
"""
57+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

0 commit comments

Comments
 (0)