Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions google/cloud/bigquery/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import google.cloud._helpers # type: ignore
from google.cloud.bigquery import _helpers
from google.cloud.bigquery import standard_sql
from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration


Expand Down Expand Up @@ -171,26 +172,32 @@ def training_runs(self) -> Sequence[Dict[str, Any]]:
)

@property
def feature_columns(self) -> Sequence[Dict[str, Any]]:
def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Input feature columns that were used to train this model.

Read-only.
"""
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("featureColumns", [])
)
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def label_columns(self) -> Sequence[Dict[str, Any]]:
def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Label columns that were used to train this model.

The output of the model will have a ``predicted_`` prefix to these columns.

Read-only.
"""
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("labelColumns", [])
)
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def best_trial_id(self) -> Optional[int]:
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,46 @@ def test_build_resource(object_under_test, resource, filter_fields, expected):
assert got == expected


def test_feature_columns(object_under_test):
from google.cloud.bigquery import standard_sql

object_under_test._properties["featureColumns"] = [
{"name": "col_1", "type": {"typeKind": "STRING"}},
{"name": "col_2", "type": {"typeKind": "FLOAT64"}},
]
expected = [
standard_sql.StandardSqlField(
"col_1",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING),
),
standard_sql.StandardSqlField(
"col_2",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64),
),
]
assert object_under_test.feature_columns == expected


def test_label_columns(object_under_test):
from google.cloud.bigquery import standard_sql

object_under_test._properties["labelColumns"] = [
{"name": "col_1", "type": {"typeKind": "STRING"}},
{"name": "col_2", "type": {"typeKind": "FLOAT64"}},
]
expected = [
standard_sql.StandardSqlField(
"col_1",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.STRING),
),
standard_sql.StandardSqlField(
"col_2",
standard_sql.StandardSqlDataType(standard_sql.StandardSqlTypeNames.FLOAT64),
),
]
assert object_under_test.label_columns == expected


def test_set_description(object_under_test):
assert not object_under_test.description
object_under_test.description = "A model description."
Expand Down