Skip to content

Commit dbdf11e

Browse files
Gerhardsa0megalinter-botlars-reimann
authored
feat: added normal plot for time series (#550)
Closes #549 ### Summary of Changes Add more plots for time series: * `TimeSeries.plot_lineplot` * `TimeSeries.plot_scatterplot` --------- Co-authored-by: megalinter-bot <[email protected]> Co-authored-by: Lars Reimann <[email protected]>
1 parent 3415045 commit dbdf11e

File tree

13 files changed

+679
-1
lines changed

13 files changed

+679
-1
lines changed

src/safeds/data/tabular/containers/_time_series.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import matplotlib.pyplot as plt
88
import pandas as pd
9+
import seaborn as sns
910

1011
from safeds.data.image.containers import Image
1112
from safeds.data.tabular.containers import Column, Row, Table, TaggedTable
@@ -36,7 +37,7 @@ def _from_tagged_table(
3637
3738
Parameters
3839
----------
39-
table : TaggedTable
40+
tagged_table: TaggedTable
4041
The tagged table.
4142
time_name: str
4243
Name of the time column.
@@ -906,3 +907,150 @@ def plot_lagplot(self, lag: int) -> Image:
906907
plt.close() # Prevents the figure from being displayed directly
907908
buffer.seek(0)
908909
return Image.from_bytes(buffer.read())
910+
911+
def plot_lineplot(self, x_column_name: str | None = None, y_column_name: str | None = None) -> Image:
    """
    Plot the time series target or the given column(s) as line plot.

    The function will take the time column as the default value for x_column_name and the target column as the
    default value for y_column_name.

    Parameters
    ----------
    x_column_name:
        The column name of the column to be plotted on the x-Axis, default is the time column.
    y_column_name:
        The column name of the column to be plotted on the y-Axis, default is the target column.

    Returns
    -------
    plot:
        The plot as an image.

    Raises
    ------
    NonNumericColumnError
        If the time series given columns contain non-numerical values.
    UnknownColumnNameError
        If one of the given names does not exist in the table

    Examples
    --------
    >>> from safeds.data.tabular.containers import TimeSeries
    >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
    >>> image = table.plot_lineplot()
    """
    # Only an explicitly given x column is type-checked; the default time column is used as-is.
    if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric():
        raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")

    if y_column_name is None:
        y_column_name = self.target.name
    elif y_column_name not in self._data.columns:
        raise UnknownColumnNameError([y_column_name])

    if x_column_name is None:
        x_column_name = self.time.name

    if not self.get_column(y_column_name).type.is_numeric():
        raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")

    # Plot straight from the backing DataFrame without touching its index. Renaming the index and calling
    # `reset_index()` afterwards (as was done before) inserts the index as an extra "index" column into
    # `self._data`, so plotting would silently change the stored data.
    fig = plt.figure()
    try:
        ax = sns.lineplot(
            data=self._data,
            x=x_column_name,
            y=y_column_name,
        )
        ax.set(xlabel=x_column_name, ylabel=y_column_name)
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            horizontalalignment="right",
        )  # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
        plt.tight_layout()

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Close exactly this figure (also on failure) so it is neither displayed directly nor leaked.
        plt.close(fig)

    buffer.seek(0)
    return Image.from_bytes(buffer.read())
983+
984+
def plot_scatterplot(
    self,
    x_column_name: str | None = None,
    y_column_name: str | None = None,
) -> Image:
    """
    Plot the time series target or the given column(s) as scatter plot.

    The function will take the time column as the default value for x_column_name and the target column as the
    default value for y_column_name.

    Parameters
    ----------
    x_column_name:
        The column name of the column to be plotted on the x-Axis, default is the time column.
    y_column_name:
        The column name of the column to be plotted on the y-Axis, default is the target column.

    Returns
    -------
    plot:
        The plot as an image.

    Raises
    ------
    NonNumericColumnError
        If the time series given columns contain non-numerical values.
    UnknownColumnNameError
        If one of the given names does not exist in the table

    Examples
    --------
    >>> from safeds.data.tabular.containers import TimeSeries
    >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
    >>> image = table.plot_scatterplot()
    """
    # Only an explicitly given x column is type-checked; the default time column is used as-is.
    if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric():
        raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")

    if y_column_name is None:
        y_column_name = self.target.name
    elif y_column_name not in self._data.columns:
        raise UnknownColumnNameError([y_column_name])

    if x_column_name is None:
        x_column_name = self.time.name

    if not self.get_column(y_column_name).type.is_numeric():
        raise NonNumericColumnError("The time series plotted column contains non-numerical columns.")

    # Plot straight from the backing DataFrame without touching its index. Renaming the index and calling
    # `reset_index()` afterwards (as was done before) inserts the index as an extra "index" column into
    # `self._data`, so plotting would silently change the stored data.
    fig = plt.figure()
    try:
        ax = sns.scatterplot(
            data=self._data,
            x=x_column_name,
            y=y_column_name,
        )
        ax.set(xlabel=x_column_name, ylabel=y_column_name)
        ax.set_xticks(ax.get_xticks())
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            horizontalalignment="right",
        )  # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
        plt.tight_layout()

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Close exactly this figure (also on failure) so it is neither displayed directly nor leaked.
        plt.close(fig)

    buffer.seek(0)
    return Image.from_bytes(buffer.read())
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)