diff --git a/2025/solid/class_based_report.py b/2025/solid/class_based_report.py index 86ff605..12bf21e 100644 --- a/2025/solid/class_based_report.py +++ b/2025/solid/class_based_report.py @@ -19,71 +19,84 @@ def read(self, file: str) -> pd.DataFrame: ... class CsvSalesReader: - def read(self, file: str) -> pd.DataFrame: - return pd.read_csv(file, parse_dates=["date"]) + def __init__(self, file: str): + self.file = file + + def read(self) -> pd.DataFrame: + return pd.read_csv(self.file, parse_dates=["date"]) class DateRangeFilter: + def __init__(self, start: datetime, end: datetime): + self.start = start + self.end = end + def apply( - self, df: pd.DataFrame, start: datetime | None, end: datetime | None + self, df: pd.DataFrame ) -> pd.DataFrame: - if start: - df = df[df["date"] >= pd.Timestamp(start)] - if end: - df = df[df["date"] <= pd.Timestamp(end)] + if self.start: + df = df[df["date"] >= pd.Timestamp(self.start)] + if self.end: + df = df[df["date"] <= pd.Timestamp(self.end)] return df -class Metric(Protocol): - def compute(self, df: pd.DataFrame) -> dict[str, object]: ... +class Report(Protocol): + def report(self, df: pd.DataFrame=None) -> dict[str, object]: ... -class CustomerCountMetric: - def compute(self, df: pd.DataFrame) -> dict[str, object]: +class CustomerCountReport: + def report(self, df: pd.DataFrame) -> dict[str, object]: return {"number_of_customers": df["name"].nunique()} -class AverageOrderValueMetric: - def compute(self, df: pd.DataFrame) -> dict[str, object]: +class AverageOrderValueReport: + def report(self, df: pd.DataFrame) -> dict[str, object]: sales = df[df["price"] > 0]["price"] avg = sales.mean() if not sales.empty else 0.0 return {"average_order_value (pre-tax)": round(avg, 2)} -class ReturnPercentageMetric: - def compute(self, df: pd.DataFrame) -> dict[str, object]: +class ReturnPercentageReport: + def report(self, df: pd.DataFrame) -> dict[str, object]: returns = df[df["price"] < 0] pct = (len(returns) / len(df)) * 100 if len(df) > 0 else 0 return {"percentage_of_returns": round(pct, 2)} -class TotalSalesMetric: - def compute(self, df: pd.DataFrame) -> dict[str, object]: +class TotalSalesReport: + def report(self, df: pd.DataFrame) -> dict[str, object]: return {"total_sales_in_period (pre-tax)": round(df["price"].sum(), 2)} +class DateRangeReport: + def __init__(self, start_date: datetime, end_date: datetime): + self.start_date = start_date + self.end_date = end_date + + def report(self, df: pd.DataFrame) -> dict[str, object]: + return { + "report_start": self.start_date.strftime("%Y-%m-%d") if self.start_date else "N/A", + "report_end": self.end_date.strftime("%Y-%m-%d") if self.end_date else "N/A" + } + + class SalesReportGenerator: def __init__( - self, reader: SalesReader, filterer: DateRangeFilter, metrics: list[Metric] + self, reader: SalesReader, filterer: DateRangeFilter, Reports: list[Report] ): self.reader = reader self.filterer = filterer - self.metrics = metrics + self.Reports = Reports - def generate(self, config: ReportConfig) -> dict[str, object]: - df = self.reader.read(config.input_file) - df = self.filterer.apply(df, config.start_date, config.end_date) + def generate(self) -> dict[str, object]: + df = self.reader.read() + df = self.filterer.apply(df) result = {} - for metric in self.metrics: - result.update(metric.compute(df)) - - result["report_start"] = ( - config.start_date.strftime("%Y-%m-%d") if config.start_date else "N/A" - ) - result["report_end"] = ( - config.end_date.strftime("%Y-%m-%d") if config.end_date else "N/A" - ) + for Report in self.Reports: + result.update(Report.report(df)) + return result @@ -101,17 +114,18 @@ def main() -> None: end_date=datetime(2024, 12, 31), ) - reader = CsvSalesReader() - filterer = DateRangeFilter() - metrics: list[Metric] = [ - CustomerCountMetric(), - AverageOrderValueMetric(), - ReturnPercentageMetric(), - TotalSalesMetric(), + reader = CsvSalesReader(file=config.input_file) + filterer = DateRangeFilter(config.start_date, config.end_date) + Reports: list[Report] = [ + CustomerCountReport(), + AverageOrderValueReport(), + ReturnPercentageReport(), + TotalSalesReport(), + DateRangeReport(config.start_date, config.end_date), ] - generator = SalesReportGenerator(reader, filterer, metrics) - report = generator.generate(config) + generator = SalesReportGenerator(reader, filterer, Reports) + report = generator.generate() writer = JSONReportWriter() writer.write(report, config.output_file)