|
6 | 6 |
|
7 | 7 | import matplotlib.pyplot as plt |
8 | 8 | import pandas as pd |
| 9 | +import seaborn as sns |
9 | 10 |
|
10 | 11 | from safeds.data.image.containers import Image |
11 | 12 | from safeds.data.tabular.containers import Column, Row, Table, TaggedTable |
@@ -36,7 +37,7 @@ def _from_tagged_table( |
36 | 37 |
|
37 | 38 | Parameters |
38 | 39 | ---------- |
39 | | - table : TaggedTable |
| 40 | + tagged_table: TaggedTable |
40 | 41 | The tagged table. |
41 | 42 | time_name: str |
42 | 43 | Name of the time column. |
@@ -906,3 +907,150 @@ def plot_lagplot(self, lag: int) -> Image: |
906 | 907 | plt.close() # Prevents the figure from being displayed directly |
907 | 908 | buffer.seek(0) |
908 | 909 | return Image.from_bytes(buffer.read()) |
| 910 | + |
| 911 | + def plot_lineplot(self, x_column_name: str | None = None, y_column_name: str | None = None) -> Image: |
| 912 | + """ |
| 913 | +
|
| 914 | + Plot the time series target or the given column(s) as line plot. |
| 915 | +
|
| 916 | + The function will take the time column as the default value for y_column_name and the target column as the |
| 917 | + default value for x_column_name. |
| 918 | +
|
| 919 | + Parameters |
| 920 | + ---------- |
| 921 | + x_column_name: |
| 922 | + The column name of the column to be plotted on the x-Axis, default is the time column. |
| 923 | + y_column_name: |
| 924 | + The column name of the column to be plotted on the y-Axis, default is the target column. |
| 925 | +
|
| 926 | + Returns |
| 927 | + ------- |
| 928 | + plot: |
| 929 | + The plot as an image. |
| 930 | +
|
| 931 | + Raises |
| 932 | + ------ |
| 933 | + NonNumericColumnError |
| 934 | + If the time series given columns contain non-numerical values. |
| 935 | +
|
| 936 | + UnknownColumnNameError |
| 937 | + If one of the given names does not exist in the table |
| 938 | +
|
| 939 | + Examples |
| 940 | + -------- |
| 941 | + >>> from safeds.data.tabular.containers import TimeSeries |
| 942 | + >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], ) |
| 943 | + >>> image = table.plot_lineplot() |
| 944 | +
|
| 945 | + """ |
| 946 | + self._data.index.name = "index" |
| 947 | + if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric(): |
| 948 | + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") |
| 949 | + |
| 950 | + if y_column_name is None: |
| 951 | + y_column_name = self.target.name |
| 952 | + |
| 953 | + elif y_column_name not in self._data.columns: |
| 954 | + raise UnknownColumnNameError([y_column_name]) |
| 955 | + |
| 956 | + if x_column_name is None: |
| 957 | + x_column_name = self.time.name |
| 958 | + |
| 959 | + if not self.get_column(y_column_name).type.is_numeric(): |
| 960 | + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") |
| 961 | + |
| 962 | + fig = plt.figure() |
| 963 | + ax = sns.lineplot( |
| 964 | + data=self._data, |
| 965 | + x=x_column_name, |
| 966 | + y=y_column_name, |
| 967 | + ) |
| 968 | + ax.set(xlabel=x_column_name, ylabel=y_column_name) |
| 969 | + ax.set_xticks(ax.get_xticks()) |
| 970 | + ax.set_xticklabels( |
| 971 | + ax.get_xticklabels(), |
| 972 | + rotation=45, |
| 973 | + horizontalalignment="right", |
| 974 | + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels |
| 975 | + plt.tight_layout() |
| 976 | + |
| 977 | + buffer = io.BytesIO() |
| 978 | + fig.savefig(buffer, format="png") |
| 979 | + plt.close() # Prevents the figure from being displayed directly |
| 980 | + buffer.seek(0) |
| 981 | + self._data = self._data.reset_index() |
| 982 | + return Image.from_bytes(buffer.read()) |
| 983 | + |
| 984 | + def plot_scatterplot( |
| 985 | + self, |
| 986 | + x_column_name: str | None = None, |
| 987 | + y_column_name: str | None = None, |
| 988 | + ) -> Image: |
| 989 | + """ |
| 990 | + Plot the time series target or the given column(s) as scatter plot. |
| 991 | +
|
| 992 | + The function will take the time column as the default value for x_column_name and the target column as the |
| 993 | + default value for y_column_name. |
| 994 | +
|
| 995 | + Parameters |
| 996 | + ---------- |
| 997 | + x_column_name: |
| 998 | + The column name of the column to be plotted on the x-Axis. |
| 999 | + y_column_name: |
| 1000 | + The column name of the column to be plotted on the y-Axis. |
| 1001 | +
|
| 1002 | + Returns |
| 1003 | + ------- |
| 1004 | + plot: |
| 1005 | + The plot as an image. |
| 1006 | +
|
| 1007 | + Raises |
| 1008 | + ------ |
| 1009 | + NonNumericColumnError |
| 1010 | + If the time series given columns contain non-numerical values. |
| 1011 | +
|
| 1012 | + UnknownColumnNameError |
| 1013 | + If one of the given names does not exist in the table |
| 1014 | +
|
| 1015 | + Examples |
| 1016 | + -------- |
| 1017 | + >>> from safeds.data.tabular.containers import TimeSeries |
| 1018 | + >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], ) |
| 1019 | + >>> image = table.plot_scatterplot() |
| 1020 | +
|
| 1021 | + """ |
| 1022 | + self._data.index.name = "index" |
| 1023 | + if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric(): |
| 1024 | + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") |
| 1025 | + |
| 1026 | + if y_column_name is None: |
| 1027 | + y_column_name = self.target.name |
| 1028 | + elif y_column_name not in self._data.columns: |
| 1029 | + raise UnknownColumnNameError([y_column_name]) |
| 1030 | + if x_column_name is None: |
| 1031 | + x_column_name = self.time.name |
| 1032 | + |
| 1033 | + if not self.get_column(y_column_name).type.is_numeric(): |
| 1034 | + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") |
| 1035 | + |
| 1036 | + fig = plt.figure() |
| 1037 | + ax = sns.scatterplot( |
| 1038 | + data=self._data, |
| 1039 | + x=x_column_name, |
| 1040 | + y=y_column_name, |
| 1041 | + ) |
| 1042 | + ax.set(xlabel=x_column_name, ylabel=y_column_name) |
| 1043 | + ax.set_xticks(ax.get_xticks()) |
| 1044 | + ax.set_xticklabels( |
| 1045 | + ax.get_xticklabels(), |
| 1046 | + rotation=45, |
| 1047 | + horizontalalignment="right", |
| 1048 | + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels |
| 1049 | + plt.tight_layout() |
| 1050 | + |
| 1051 | + buffer = io.BytesIO() |
| 1052 | + fig.savefig(buffer, format="png") |
| 1053 | + plt.close() # Prevents the figure from being displayed directly |
| 1054 | + buffer.seek(0) |
| 1055 | + self._data = self._data.reset_index() |
| 1056 | + return Image.from_bytes(buffer.read()) |
0 commit comments