241 changes: 241 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -1170,6 +1170,247 @@ def test_parse_datatype_string(self):
)
self.assertEqual(VariantType(), _parse_datatype_string("variant"))

def test_tree_string(self):
schema1 = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")

self.assertEqual(
schema1.treeString().split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
" | | |-- c5: integer (nullable = true)",
" | | |-- c6: integer (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(-1).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
" | | |-- c5: integer (nullable = true)",
" | | |-- c6: integer (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(0).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
" | | |-- c5: integer (nullable = true)",
" | | |-- c6: integer (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(1).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(2).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(3).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
" | | |-- c5: integer (nullable = true)",
" | | |-- c6: integer (nullable = true)",
"",
],
)
self.assertEqual(
schema1.treeString(4).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: struct (nullable = true)",
" | |-- c3: integer (nullable = true)",
" | |-- c4: struct (nullable = true)",
" | | |-- c5: integer (nullable = true)",
" | | |-- c6: integer (nullable = true)",
"",
],
)

schema2 = DataType.fromDDL(
"c1 INT, c2 ARRAY<STRUCT<c3: INT>>, c4 STRUCT<c5: INT, c6: ARRAY<ARRAY<INT>>>"
)
self.assertEqual(
schema2.treeString(0).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: array (nullable = true)",
" | |-- element: struct (containsNull = true)",
" | | |-- c3: integer (nullable = true)",
" |-- c4: struct (nullable = true)",
" | |-- c5: integer (nullable = true)",
" | |-- c6: array (nullable = true)",
" | | |-- element: array (containsNull = true)",
" | | | |-- element: integer (containsNull = true)",
"",
],
)
self.assertEqual(
schema2.treeString(1).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: array (nullable = true)",
" |-- c4: struct (nullable = true)",
"",
],
)
self.assertEqual(
schema2.treeString(2).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: array (nullable = true)",
" | |-- element: struct (containsNull = true)",
" |-- c4: struct (nullable = true)",
" | |-- c5: integer (nullable = true)",
" | |-- c6: array (nullable = true)",
"",
],
)
self.assertEqual(
schema2.treeString(3).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: array (nullable = true)",
" | |-- element: struct (containsNull = true)",
" | | |-- c3: integer (nullable = true)",
" |-- c4: struct (nullable = true)",
" | |-- c5: integer (nullable = true)",
" | |-- c6: array (nullable = true)",
" | | |-- element: array (containsNull = true)",
"",
],
)
self.assertEqual(
schema2.treeString(4).split("\n"),
[
"root",
" |-- c1: integer (nullable = true)",
" |-- c2: array (nullable = true)",
" | |-- element: struct (containsNull = true)",
" | | |-- c3: integer (nullable = true)",
" |-- c4: struct (nullable = true)",
" | |-- c5: integer (nullable = true)",
" | |-- c6: array (nullable = true)",
" | | |-- element: array (containsNull = true)",
" | | | |-- element: integer (containsNull = true)",
"",
],
)

schema3 = DataType.fromDDL(
"c1 MAP<INT, STRUCT<c2: MAP<INT, INT>>>, c3 STRUCT<c4: MAP<INT, MAP<INT, INT>>>"
)
self.assertEqual(
schema3.treeString(0).split("\n"),
[
"root",
" |-- c1: map (nullable = true)",
" | |-- key: integer",
" | |-- value: struct (valueContainsNull = true)",
" | | |-- c2: map (nullable = true)",
" | | | |-- key: integer",
" | | | |-- value: integer (valueContainsNull = true)",
" |-- c3: struct (nullable = true)",
" | |-- c4: map (nullable = true)",
" | | |-- key: integer",
" | | |-- value: map (valueContainsNull = true)",
" | | | |-- key: integer",
" | | | |-- value: integer (valueContainsNull = true)",
"",
],
)
self.assertEqual(
schema3.treeString(1).split("\n"),
[
"root",
" |-- c1: map (nullable = true)",
" |-- c3: struct (nullable = true)",
"",
],
)
self.assertEqual(
schema3.treeString(2).split("\n"),
[
"root",
" |-- c1: map (nullable = true)",
" | |-- key: integer",
" | |-- value: struct (valueContainsNull = true)",
" |-- c3: struct (nullable = true)",
" | |-- c4: map (nullable = true)",
"",
],
)
self.assertEqual(
schema3.treeString(3).split("\n"),
[
"root",
" |-- c1: map (nullable = true)",
" | |-- key: integer",
" | |-- value: struct (valueContainsNull = true)",
" | | |-- c2: map (nullable = true)",
" |-- c3: struct (nullable = true)",
" | |-- c4: map (nullable = true)",
" | | |-- key: integer",
" | | |-- value: map (valueContainsNull = true)",
"",
],
)
self.assertEqual(
schema3.treeString(4).split("\n"),
[
"root",
" |-- c1: map (nullable = true)",
" | |-- key: integer",
" | |-- value: struct (valueContainsNull = true)",
" | | |-- c2: map (nullable = true)",
" | | | |-- key: integer",
" | | | |-- value: integer (valueContainsNull = true)",
" |-- c3: struct (nullable = true)",
" | |-- c4: map (nullable = true)",
" | | |-- key: integer",
" | | |-- value: map (valueContainsNull = true)",
" | | | |-- key: integer",
" | | | |-- value: integer (valueContainsNull = true)",
"",
],
)

def test_metadata_null(self):
schema = StructType(
[
87 changes: 86 additions & 1 deletion python/pyspark/sql/types.py
@@ -47,7 +47,12 @@

from pyspark.util import is_remote_only
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy, get_active_spark_context
from pyspark.sql.utils import (
has_numpy,
get_active_spark_context,
escape_meta_characters,
StringConcat,
)
from pyspark.sql.variant_utils import VariantUtils
from pyspark.errors import (
PySparkNotImplementedError,
@@ -99,6 +104,8 @@
"VariantVal",
]

# Default maxDepth for the tree-formatting helpers below; matches the JVM's Int.MaxValue
# so the full tree is rendered unless a smaller depth is requested.
_JVM_INT_MAX: int = (1 << 31) - 1


class DataType:
"""Base class for data types."""
@@ -199,6 +206,17 @@ def fromDDL(cls, ddl: str) -> "DataType":
assert len(schema) == 1
return schema[0].dataType

@classmethod
def _data_type_build_formatted_string(
cls,
dataType: "DataType",
prefix: str,
stringConcat: StringConcat,
maxDepth: int,
) -> None:
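# Only complex types (arrays, structs, maps) have nested children to render;
# atomic element types end the recursion here.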
if isinstance(dataType, (ArrayType, StructType, MapType)):
dataType._build_formatted_string(prefix, stringConcat, maxDepth - 1)


# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
@@ -734,6 +752,21 @@ def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]:
return obj
return obj and [self.elementType.fromInternal(v) for v in obj]

def _build_formatted_string(
self,
prefix: str,
stringConcat: StringConcat,
maxDepth: int = _JVM_INT_MAX,
) -> None:
if maxDepth > 0:
stringConcat.append(
f"{prefix}-- element: {self.elementType.typeName()} "
+ f"(containsNull = {str(self.containsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
self.elementType, f"{prefix} |", stringConcat, maxDepth
)


class MapType(DataType):
"""Map data type.
@@ -868,6 +901,25 @@ def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]:
(self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items()
)

def _build_formatted_string(
self,
prefix: str,
stringConcat: StringConcat,
maxDepth: int = _JVM_INT_MAX,
) -> None:
if maxDepth > 0:
stringConcat.append(f"{prefix}-- key: {self.keyType.typeName()}\n")
DataType._data_type_build_formatted_string(
self.keyType, f"{prefix} |", stringConcat, maxDepth
)
stringConcat.append(
f"{prefix}-- value: {self.valueType.typeName()} "
+ f"(valueContainsNull = {str(self.valueContainsNull).lower()})\n"
)
DataType._data_type_build_formatted_string(
self.valueType, f"{prefix} |", stringConcat, maxDepth
)


class StructField(DataType):
"""A field in :class:`StructType`.
@@ -1016,6 +1068,21 @@ def typeName(self) -> str: # type: ignore[override]
message_parameters={},
)

def _build_formatted_string(
self,
prefix: str,
stringConcat: StringConcat,
maxDepth: int = _JVM_INT_MAX,
) -> None:
if maxDepth > 0:
stringConcat.append(
f"{prefix}-- {escape_meta_characters(self.name)}: {self.dataType.typeName()} "
+ f"(nullable = {str(self.nullable).lower()})\n"
)
DataType._data_type_build_formatted_string(
self.dataType, f"{prefix} |", stringConcat, maxDepth
)


class StructType(DataType):
"""Struct type, consisting of a list of :class:`StructField`.
@@ -1436,6 +1503,24 @@ def fromInternal(self, obj: Tuple) -> "Row":
values = obj
return _create_row(self.names, values)

def _build_formatted_string(
self,
prefix: str,
stringConcat: StringConcat,
maxDepth: int = _JVM_INT_MAX,
) -> None:
for field in self.fields:
field._build_formatted_string(prefix, stringConcat, maxDepth)

def treeString(self, maxDepth: int = _JVM_INT_MAX) -> str:
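"""Return a tree-formatted string describing the fields of this struct.
A positive maxDepth limits how many nesting levels are rendered; any
other value renders the full tree."""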
stringConcat = StringConcat()
stringConcat.append("root\n")
prefix = " |"
depth = maxDepth if maxDepth > 0 else _JVM_INT_MAX
for field in self.fields:
field._build_formatted_string(prefix, stringConcat, depth)
return stringConcat.toString()


class VariantType(AtomicType):
"""
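For reference, a minimal usage sketch of the new StructType.treeString API (not part of the patch; the schema, local session setup, and depth values below are illustrative):

from pyspark.sql import SparkSession
from pyspark.sql.types import DataType, StructType

# A local SparkSession is started because DataType.fromDDL parses the DDL
# string through the active session.
spark = SparkSession.builder.master("local[1]").appName("treeString-demo").getOrCreate()

schema = DataType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: ARRAY<INT>>")
assert isinstance(schema, StructType)

# Full tree: every nesting level is rendered.
print(schema.treeString())

# Depth-limited tree: only the top-level fields are rendered.
print(schema.treeString(1))

spark.stop()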