From 28b0d842a72b8e4794bfd2457d50c076cc746485 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 8 Aug 2025 15:55:29 +0100 Subject: [PATCH 1/7] Split readers into separate files --- src/muse/readers/__init__.py | 5 +- src/muse/readers/csv.py | 1475 -------------------------- src/muse/readers/csv/__init__.py | 67 ++ src/muse/readers/csv/agents.py | 104 ++ src/muse/readers/csv/assets.py | 58 + src/muse/readers/csv/commodities.py | 57 + src/muse/readers/csv/general.py | 49 + src/muse/readers/csv/helpers.py | 342 ++++++ src/muse/readers/csv/market.py | 126 +++ src/muse/readers/csv/presets.py | 109 ++ src/muse/readers/csv/regression.py | 185 ++++ src/muse/readers/csv/technologies.py | 368 +++++++ src/muse/readers/csv/trade.py | 106 ++ 13 files changed, 1574 insertions(+), 1477 deletions(-) delete mode 100644 src/muse/readers/csv.py create mode 100644 src/muse/readers/csv/__init__.py create mode 100644 src/muse/readers/csv/agents.py create mode 100644 src/muse/readers/csv/assets.py create mode 100644 src/muse/readers/csv/commodities.py create mode 100644 src/muse/readers/csv/general.py create mode 100644 src/muse/readers/csv/helpers.py create mode 100644 src/muse/readers/csv/market.py create mode 100644 src/muse/readers/csv/presets.py create mode 100644 src/muse/readers/csv/regression.py create mode 100644 src/muse/readers/csv/technologies.py create mode 100644 src/muse/readers/csv/trade.py diff --git a/src/muse/readers/__init__.py b/src/muse/readers/__init__.py index 16900043d..9c7edf403 100644 --- a/src/muse/readers/__init__.py +++ b/src/muse/readers/__init__.py @@ -1,8 +1,9 @@ """Aggregates methods to read data from file.""" from muse.defaults import DATA_DIRECTORY -from muse.readers.csv import * # noqa: F403 -from muse.readers.toml import read_settings # noqa: F401 + +from .csv import * # noqa: F403 +from .toml import read_settings # noqa: F401 DEFAULT_SETTINGS_PATH = DATA_DIRECTORY / "default_settings.toml" """Default settings path.""" diff --git a/src/muse/readers/csv.py b/src/muse/readers/csv.py deleted file mode 100644 index 9a5018f63..000000000 --- a/src/muse/readers/csv.py +++ /dev/null @@ -1,1475 +0,0 @@ -"""Ensemble of functions to read MUSE data. - -In general, there are three functions per input file: -`read_x`: This is the overall function that is called to read the data. It takes a - `Path` as input, and returns the relevant data structure (usually an xarray). The - process is generally broken down into two functions that are called by `read_x`: - -`read_x_csv`: This takes a path to a csv file as input and returns a pandas dataframe. - There are some consistency checks, such as checking data types and columns. There - is also some minor processing at this stage, such as standardising column names, - but no structural changes to the data. The general rule is that anything returned - by this function should still be valid as an input file if saved to csv. -`process_x`: This is where more major processing and reformatting of the data is done. - It takes the dataframe from `read_x_csv` and returns the final data structure - (usually an xarray). There are also some more checks (e.g. checking for nan - values). 
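Illustrative aside, not part of the patch: a hypothetical reader following the read_x / read_x_csv / process_x convention described in the docstring above. The read_example* names and the region/year/value columns are made up.

```python
from pathlib import Path

import pandas as pd
import xarray as xr


def read_example_csv(path: Path) -> pd.DataFrame:
    # Light-touch stage: load the file, standardise column names, check columns.
    df = pd.read_csv(path).rename(columns={"RegionName": "region", "Time": "year"})
    missing = {"region", "year", "value"} - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    return df


def process_example(df: pd.DataFrame) -> xr.DataArray:
    # Restructuring stage: reshape into the final xarray object and check for NaNs.
    da = df.set_index(["region", "year"])["value"].to_xarray()
    if da.isnull().any():
        raise ValueError("Missing data after reshaping.")
    return da


def read_example(path: Path) -> xr.DataArray:
    # Top-level entry point: compose the two stages.
    return process_example(read_example_csv(path))
```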
- -Most of the processing is shared by a few helper functions: -- read_csv: reads a csv file and returns a dataframe -- standardize_dataframe: standardizes the dataframe to a common format -- create_multiindex: creates a multiindex from a dataframe -- create_xarray_dataset: creates an xarray dataset from a dataframe - -A few other helpers perform common operations on xarrays: -- create_assets: creates assets from technologies -- check_commodities: checks commodities and fills missing values - -""" - -from __future__ import annotations - -__all__ = [ - "read_agent_parameters", - "read_attribute_table", - "read_csv", - "read_existing_trade", - "read_global_commodities", - "read_initial_capacity", - "read_initial_market", - "read_io_technodata", - "read_macro_drivers", - "read_presets", - "read_regression_parameters", - "read_technodata_timeslices", - "read_technodictionary", - "read_technologies", - "read_timeslice_shares", - "read_trade_technodata", -] - -from logging import getLogger -from pathlib import Path - -import pandas as pd -import xarray as xr - -from muse.utilities import camel_to_snake - -# Global mapping of column names to their standardized versions -# This is for backwards compatibility with old file formats -COLUMN_RENAMES = { - "process_name": "technology", - "process": "technology", - "sector_name": "sector", - "region_name": "region", - "time": "year", - "commodity_name": "commodity", - "comm_type": "commodity_type", - "commodity_price": "prices", - "units_commodity_price": "units_prices", - "enduse": "end_use", - "sn": "timeslice", - "commodity_emission_factor_CO2": "emmission_factor", - "utilisation_factor": "utilization_factor", - "objsort": "obj_sort", - "objsort1": "obj_sort1", - "objsort2": "obj_sort2", - "objsort3": "obj_sort3", - "time_slice": "timeslice", - "price": "prices", -} - -# Columns who's values should be converted from camelCase to snake_case -CAMEL_TO_SNAKE_COLUMNS = [ - "tech_type", - "commodity", - "commodity_type", - "agent_share", - "attribute", - "sector", - "region", - "parameter", -] - -# Global mapping of column names to their expected types -COLUMN_TYPES = { - "year": int, - "region": str, - "technology": str, - "commodity": str, - "sector": str, - "attribute": str, - "variable": str, - "timeslice": int, # For tables that require int timeslice instead of month etc. - "name": str, - "commodity_type": str, - "tech_type": str, - "type": str, - "function_type": str, - "level": str, - "search_rule": str, - "decision_method": str, - "quantity": float, - "share": float, - "coeff": str, - "value": float, - "utilization_factor": float, - "minimum_service_factor": float, - "maturity_threshold": float, - "spend_limit": float, - "prices": float, - "emmission_factor": float, -} - -DEFAULTS = { - "cap_par": 0, - "cap_exp": 1, - "fix_par": 0, - "fix_exp": 1, - "var_par": 0, - "var_exp": 1, - "interest_rate": 0, - "utilization_factor": 1, - "minimum_service_factor": 0, - "search_rule": "all", - "decision_method": "single", -} - - -def standardize_columns(data: pd.DataFrame) -> pd.DataFrame: - """Standardizes column names in a DataFrame. - - This function: - 1. Converts column names to snake_case - 2. Applies the global COLUMN_RENAMES mapping - 3. 
Preserves any columns not in the mapping - - Args: - data: DataFrame to standardize - - Returns: - DataFrame with standardized column names - """ - # Drop index column if present - if data.columns[0] == "" or data.columns[0].startswith("Unnamed"): - data = data.iloc[:, 1:] - - # Convert columns to snake_case - data = data.rename(columns=camel_to_snake) - - # Then apply global mapping - data = data.rename(columns=COLUMN_RENAMES) - - # Make sure there are no duplicate columns - if len(data.columns) != len(set(data.columns)): - raise ValueError(f"Duplicate columns in {data.columns}") - - return data - - -def create_multiindex( - data: pd.DataFrame, - index_columns: list[str], - index_names: list[str], - drop_columns: bool = True, -) -> pd.DataFrame: - """Creates a MultiIndex from specified columns. - - Args: - data: DataFrame to create index from - index_columns: List of column names to use for index - index_names: List of names for the index levels - drop_columns: Whether to drop the original columns - - Returns: - DataFrame with new MultiIndex - """ - index = pd.MultiIndex.from_arrays( - [data[col] for col in index_columns], names=index_names - ) - result = data.copy() - result.index = index - if drop_columns: - result = result.drop(columns=index_columns) - return result - - -def create_xarray_dataset( - data: pd.DataFrame, - disallow_nan: bool = True, -) -> xr.Dataset: - """Creates an xarray Dataset from a DataFrame with standardized options. - - Args: - data: DataFrame to convert - disallow_nan: Whether to raise an error if NaN values are found - - Returns: - xarray Dataset - """ - result = xr.Dataset.from_dataframe(data) - if disallow_nan: - nan_coords = get_nan_coordinates(result) - if nan_coords: - raise ValueError(f"Missing data for coordinates: {nan_coords}") - - if "year" in result.coords: - result = result.assign_coords(year=result.year.astype(int)) - result = result.sortby("year") - assert len(set(result.year.values)) == result.year.data.size # no duplicates - - return result - - -def get_nan_coordinates(dataset: xr.Dataset) -> list[tuple]: - """Get coordinates of a Dataset where any data variable has NaN values.""" - any_nan = sum(var.isnull() for var in dataset.data_vars.values()) - if any_nan.any(): - return any_nan.where(any_nan, drop=True).to_dataframe(name="").index.to_list() - return [] - - -def convert_column_types(data: pd.DataFrame) -> pd.DataFrame: - """Converts DataFrame columns to their expected types. - - Args: - data: DataFrame to convert - - Returns: - DataFrame with converted column types - """ - result = data.copy() - for column, expected_type in COLUMN_TYPES.items(): - if column in result.columns: - try: - if expected_type is int: - result[column] = pd.to_numeric(result[column], downcast="integer") - elif expected_type is float: - result[column] = pd.to_numeric(result[column]).astype(float) - elif expected_type is str: - result[column] = result[column].astype(str) - except (ValueError, TypeError) as e: - raise ValueError( - f"Could not convert column '{column}' to {expected_type.__name__}: {e}" # noqa: E501 - ) - return result - - -def standardize_dataframe( - data: pd.DataFrame, - required_columns: list[str] | None = None, - exclude_extra_columns: bool = False, -) -> pd.DataFrame: - """Standardizes a DataFrame to a common format. - - Args: - data: DataFrame to standardize - required_columns: List of column names that must be present (optional) - exclude_extra_columns: If True, exclude any columns not in required_columns list - (optional). 
This can be important if extra columns can mess up the resulting - xarray object. - - Returns: - DataFrame containing the standardized data - """ - if required_columns is None: - required_columns = [] - - # Standardize column names - data = standardize_columns(data) - - # Convert specified column values from camelCase to snake_case - for col in CAMEL_TO_SNAKE_COLUMNS: - if col in data.columns: - data[col] = data[col].apply(camel_to_snake) - - # Fill missing values with defaults - data = data.fillna(DEFAULTS) - for col, default in DEFAULTS.items(): - if col not in data.columns and col in required_columns: - data[col] = default - - # Check/convert data types - data = convert_column_types(data) - - # Validate required columns if provided - if required_columns: - missing_columns = [col for col in required_columns if col not in data.columns] - if missing_columns: - raise ValueError(f"Missing required columns: {missing_columns}") - - # Exclude extra columns if requested - if exclude_extra_columns: - data = data[list(required_columns)] - - return data - - -def read_csv( - path: Path, - float_precision: str = "high", - required_columns: list[str] | None = None, - exclude_extra_columns: bool = False, - msg: str | None = None, -) -> pd.DataFrame: - """Reads and standardizes a CSV file into a DataFrame. - - Args: - path: Path to the CSV file - float_precision: Precision to use when reading floats - required_columns: List of column names that must be present (optional) - exclude_extra_columns: If True, exclude any columns not in required_columns list - (optional). This can be important if extra columns can mess up the resulting - xarray object. - msg: Message to log (optional) - - Returns: - DataFrame containing the standardized data - """ - # Log message - if msg: - getLogger(__name__).info(msg) - - # Check if file exists - if not path.is_file(): - raise OSError(f"{path} does not exist.") - - # Check if there's a units row (in which case we need to skip it) - with open(path) as f: - next(f) # Skip header row - first_data_row = f.readline().strip() - skiprows = [1] if first_data_row.startswith("Unit") else None - - # Read the file - data = pd.read_csv( - path, - float_precision=float_precision, - low_memory=False, - skiprows=skiprows, - ) - - # Standardize the DataFrame - return standardize_dataframe( - data, - required_columns=required_columns, - exclude_extra_columns=exclude_extra_columns, - ) - - -def check_commodities( - data: xr.Dataset | xr.DataArray, fill_missing: bool = True, fill_value: float = 0 -) -> xr.Dataset | xr.DataArray: - """Validates and optionally fills missing commodities in data.""" - from muse.commodities import COMMODITIES - - # Make sure there are no commodities in data but not in global commodities - extra_commodities = [ - c for c in data.commodity.values if c not in COMMODITIES.commodity.values - ] - if extra_commodities: - raise ValueError( - "The following commodities were not found in global commodities file: " - f"{extra_commodities}" - ) - - # Add any missing commodities with fill_value - if fill_missing: - data = data.reindex( - commodity=COMMODITIES.commodity.values, fill_value=fill_value - ) - return data - - -def create_assets(data: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: - """Creates assets from technology data.""" - # Rename technology to asset - result = data.drop_vars("technology").rename(technology="asset") - result["technology"] = "asset", data.technology.values - - # Add installed year - result["installed"] = ("asset", 
[int(result.year.min())] * len(result.technology)) - return result - - -def read_technodictionary(path: Path) -> xr.Dataset: - """Reads and processes technodictionary data from a CSV file.""" - df = read_technodictionary_csv(path) - return process_technodictionary(df) - - -def read_technodictionary_csv(path: Path) -> pd.DataFrame: - """Reads and formats technodata into a DataFrame.""" - required_columns = { - "cap_exp", - "region", - "var_par", - "fix_exp", - "interest_rate", - "utilization_factor", - "minimum_service_factor", - "year", - "cap_par", - "var_exp", - "technology", - "technical_life", - "fix_par", - } - data = read_csv( - path, - required_columns=required_columns, - msg=f"Reading technodictionary from {path}.", - ) - - # Check for deprecated columns - if "fuel" in data.columns: - msg = ( - f"The 'fuel' column in {path} has been deprecated. " - "This information is now determined from CommIn files. " - "Please remove this column from your Technodata files." - ) - getLogger(__name__).warning(msg) - if "end_use" in data.columns: - msg = ( - f"The 'end_use' column in {path} has been deprecated. " - "This information is now determined from CommOut files. " - "Please remove this column from your Technodata files." - ) - getLogger(__name__).warning(msg) - if "scaling_size" in data.columns: - msg = ( - f"The 'scaling_size' column in {path} has been deprecated. " - "Please remove this column from your Technodata files." - ) - getLogger(__name__).warning(msg) - - return data - - -def process_technodictionary(data: pd.DataFrame) -> xr.Dataset: - """Processes technodictionary DataFrame into an xarray Dataset.""" - # Create multiindex for technology and region - data = create_multiindex( - data, - index_columns=["technology", "region", "year"], - index_names=["technology", "region", "year"], - drop_columns=True, - ) - - # Create dataset - result = create_xarray_dataset(data) - - # Handle tech_type if present - if "type" in result.variables: - result["tech_type"] = result.type.isel(region=0, year=0) - - return result - - -def read_technodata_timeslices(path: Path) -> xr.Dataset: - """Reads and processes technodata timeslices from a CSV file.""" - df = read_technodata_timeslices_csv(path) - return process_technodata_timeslices(df) - - -def read_technodata_timeslices_csv(path: Path) -> pd.DataFrame: - """Reads and formats technodata timeslices into a DataFrame.""" - from muse.timeslices import TIMESLICE - - timeslice_columns = set(TIMESLICE.coords["timeslice"].indexes["timeslice"].names) - required_columns = { - "utilization_factor", - "technology", - "minimum_service_factor", - "region", - "year", - } | timeslice_columns - return read_csv( - path, - required_columns=required_columns, - exclude_extra_columns=True, - msg=f"Reading technodata timeslices from {path}.", - ) - - -def process_technodata_timeslices(data: pd.DataFrame) -> xr.Dataset: - """Processes technodata timeslices DataFrame into an xarray Dataset.""" - from muse.timeslices import TIMESLICE, sort_timeslices - - # Create multiindex for all columns except factor columns - factor_columns = ["utilization_factor", "minimum_service_factor", "obj_sort"] - index_columns = [col for col in data.columns if col not in factor_columns] - data = create_multiindex( - data, - index_columns=index_columns, - index_names=index_columns, - drop_columns=True, - ) - - # Create dataset - result = create_xarray_dataset(data) - - # Stack timeslice levels (month, day, hour) into a single timeslice dimension - timeslice_levels = 
TIMESLICE.coords["timeslice"].indexes["timeslice"].names - if all(level in result.dims for level in timeslice_levels): - result = result.stack(timeslice=timeslice_levels) - return sort_timeslices(result) - - -def read_io_technodata(path: Path) -> xr.Dataset: - """Reads and processes input/output technodata from a CSV file.""" - df = read_io_technodata_csv(path) - return process_io_technodata(df) - - -def read_io_technodata_csv(path: Path) -> pd.DataFrame: - """Reads process inputs or outputs into a DataFrame.""" - data = read_csv( - path, - required_columns=["technology", "region", "year"], - msg=f"Reading IO technodata from {path}.", - ) - - # Unspecified Level values default to "fixed" - if "level" in data.columns: - data["level"] = data["level"].fillna("fixed") - else: - # Particularly relevant to outputs files where the Level column is omitted by - # default, as only "fixed" outputs are allowed. - data["level"] = "fixed" - - return data - - -def process_io_technodata(data: pd.DataFrame) -> xr.Dataset: - """Processes IO technodata DataFrame into an xarray Dataset.""" - from muse.commodities import COMMODITIES - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. single "commodity" column) - data = data.melt( - id_vars=["technology", "region", "year", "level"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Pivot data to create fixed and flexible columns - data = data.pivot( - index=["technology", "region", "year", "commodity"], - columns="level", - values="value", - ) - - # Create xarray dataset - result = create_xarray_dataset(data) - - # Fill in flexible data - if "flexible" in result.data_vars: - result["flexible"] = result.flexible.fillna(0) - else: - result["flexible"] = xr.zeros_like(result.fixed).rename("flexible") - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_technologies( - technodata_path: Path, - comm_out_path: Path, - comm_in_path: Path, - time_framework: list[int], - interpolation_mode: str = "linear", - technodata_timeslices_path: Path | None = None, -) -> xr.Dataset: - """Reads and processes technology data from multiple CSV files. - - Will also interpolate data to the time framework if provided. - - Args: - technodata_path: path to the technodata file - comm_out_path: path to the comm_out file - comm_in_path: path to the comm_in file - time_framework: list of years to interpolate data to - interpolation_mode: Interpolation mode to use - technodata_timeslices_path: path to the technodata_timeslices file - - Returns: - xr.Dataset: Dataset containing the processed technology data. Any fields - that differ by year will contain a "year" dimension interpolated to the - time framework. Other fields will not have a "year" dimension. 
- """ - # Read all data - technodata = read_technodictionary(technodata_path) - comm_out = read_io_technodata(comm_out_path) - comm_in = read_io_technodata(comm_in_path) - technodata_timeslices = ( - read_technodata_timeslices(technodata_timeslices_path) - if technodata_timeslices_path - else None - ) - - # Assemble xarray Dataset - return process_technologies( - technodata, - comm_out, - comm_in, - time_framework, - interpolation_mode, - technodata_timeslices, - ) - - -def process_technologies( - technodata: xr.Dataset, - comm_out: xr.Dataset, - comm_in: xr.Dataset, - time_framework: list[int], - interpolation_mode: str = "linear", - technodata_timeslices: xr.Dataset | None = None, -) -> xr.Dataset: - """Processes technology data DataFrames into an xarray Dataset.""" - from muse.commodities import COMMODITIES, CommodityUsage - from muse.timeslices import drop_timeslice - from muse.utilities import interpolate_technodata - - # Process inputs/outputs - ins = comm_in.rename(flexible="flexible_inputs", fixed="fixed_inputs") - outs = comm_out.rename(flexible="flexible_outputs", fixed="fixed_outputs") - - # Legacy: Remove flexible outputs - if not (outs["flexible_outputs"] == 0).all(): - raise ValueError( - "'flexible' outputs are not permitted. All outputs must be 'fixed'" - ) - outs = outs.drop_vars("flexible_outputs") - - # Collect all years from the time framework and data files - time_framework = list( - set(time_framework).union( - technodata.year.values.tolist(), - ins.year.values.tolist(), - outs.year.values.tolist(), - technodata_timeslices.year.values.tolist() if technodata_timeslices else [], - ) - ) - - # Interpolate data to match the time framework - technodata = interpolate_technodata(technodata, time_framework, interpolation_mode) - outs = interpolate_technodata(outs, time_framework, interpolation_mode) - ins = interpolate_technodata(ins, time_framework, interpolation_mode) - if technodata_timeslices: - technodata_timeslices = interpolate_technodata( - technodata_timeslices, time_framework, interpolation_mode - ) - - # Merge inputs/outputs with technodata - technodata = technodata.merge(outs).merge(ins) - - # Merge technodata_timeslices if provided. This will prioritise values defined in - # technodata_timeslices, and fallback to the non-timesliced technodata for any - # values that are not defined in technodata_timeslices. 
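Toy illustration, not part of the patch: how combine_first gives priority to timesliced utilization factors and falls back to the non-timesliced technodata, as the comment above describes. Technology names, timeslice labels and numbers are made up, and the explicit broadcast_like is only for clarity.

```python
import numpy as np
import xarray as xr

# Non-timesliced utilization factor per technology (the fallback values).
base = xr.DataArray(
    [0.8, 0.7],
    coords={"technology": ["gasCCGT", "windturbine"]},
    dims="technology",
)
# Timesliced values; NaN marks combinations left undefined in the timesliced table.
sliced = xr.DataArray(
    [[0.9, np.nan], [0.5, 0.6]],
    coords={"timeslice": [1, 2], "technology": ["gasCCGT", "windturbine"]},
    dims=("timeslice", "technology"),
)
# Timesliced values win; undefined entries fall back to the per-technology value,
# e.g. windturbine at timeslice 1 becomes 0.7.
merged = sliced.combine_first(base.broadcast_like(sliced))
```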
- if technodata_timeslices: - technodata["utilization_factor"] = ( - technodata_timeslices.utilization_factor.combine_first( - technodata.utilization_factor - ) - ) - technodata["minimum_service_factor"] = drop_timeslice( - technodata_timeslices.minimum_service_factor.combine_first( - technodata.minimum_service_factor - ) - ) - - # Check commodities - technodata = check_commodities(technodata, fill_missing=False) - - # Add info about commodities - technodata = technodata.merge(COMMODITIES.sel(commodity=technodata.commodity)) - - # Add commodity usage flags - technodata["comm_usage"] = ( - "commodity", - CommodityUsage.from_technologies(technodata).values, - ) - technodata = technodata.drop_vars("commodity_type") - - # Check utilization and minimum service factors - check_utilization_and_minimum_service_factors(technodata) - - return technodata - - -def read_initial_capacity(path: Path) -> xr.DataArray: - """Reads and processes initial capacity data from a CSV file.""" - df = read_initial_capacity_csv(path) - return process_initial_capacity(df) - - -def read_initial_capacity_csv(path: Path) -> pd.DataFrame: - """Reads and formats data about initial capacity into a DataFrame.""" - required_columns = { - "region", - "technology", - } - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading initial capacity from {path}.", - ) - - -def process_initial_capacity(data: pd.DataFrame) -> xr.DataArray: - """Processes initial capacity DataFrame into an xarray DataArray.""" - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select year columns - year_columns = [col for col in data.columns if col.isdigit()] - - # Convert year columns to long format (i.e. single "year" column) - data = data.melt( - id_vars=["technology", "region"], - value_vars=year_columns, - var_name="year", - value_name="value", - ) - - # Create multiindex for region, technology, and year - data = create_multiindex( - data, - index_columns=["technology", "region", "year"], - index_names=["technology", "region", "year"], - drop_columns=True, - ) - - # Create Dataarray - result = create_xarray_dataset(data).value.astype(float) - - # Create assets - result = create_assets(result) - return result - - -def read_global_commodities(path: Path) -> xr.Dataset: - """Reads and processes global commodities data from a CSV file.""" - df = read_global_commodities_csv(path) - return process_global_commodities(df) - - -def read_global_commodities_csv(path: Path) -> pd.DataFrame: - """Reads commodities information from input into a DataFrame.""" - # Due to legacy reasons, users can supply both Commodity and CommodityName columns - # In this case, we need to remove the Commodity column to avoid conflicts - # This is fine because Commodity just contains a long description that isn't needed - getLogger(__name__).info(f"Reading global commodities from {path}.") - df = pd.read_csv(path) - df = df.rename(columns=camel_to_snake) - if "commodity" in df.columns and "commodity_name" in df.columns: - df = df.drop(columns=["commodity"]) - - required_columns = { - "commodity", - "commodity_type", - } - data = standardize_dataframe( - df, - required_columns=required_columns, - ) - - # Raise warning if units are not defined - if "unit" not in data.columns: - msg = ( - "No units defined for commodities. Please define units for all commodities " - "in the global commodities file." 
- ) - getLogger(__name__).warning(msg) - - return data - - -def process_global_commodities(data: pd.DataFrame) -> xr.Dataset: - """Processes global commodities DataFrame into an xarray Dataset.""" - # Drop description column if present. It's useful to include in the file, but we - # don't need it for the simulation. - if "description" in data.columns: - data = data.drop(columns=["description"]) - - data.index = [u for u in data.commodity] - data = data.drop("commodity", axis=1) - data.index.name = "commodity" - return create_xarray_dataset(data) - - -def read_agent_parameters(path: Path) -> pd.DataFrame: - """Reads and processes agent parameters from a CSV file.""" - df = read_agent_parameters_csv(path) - return process_agent_parameters(df) - - -def read_agent_parameters_csv(path: Path) -> pd.DataFrame: - """Reads standard MUSE agent-declaration csv-files into a DataFrame.""" - required_columns = { - "search_rule", - "quantity", - "region", - "type", - "name", - "agent_share", - "decision_method", - } - data = read_csv( - path, - required_columns=required_columns, - msg=f"Reading agent parameters from {path}.", - ) - - # Check for deprecated retrofit agents - if "type" in data.columns: - retrofit_agents = data[data.type.str.lower().isin(["retrofit", "retro"])] - if not retrofit_agents.empty: - msg = ( - "Retrofit agents will be deprecated in a future release. " - "Please modify your model to use only agents of the 'New' type." - ) - getLogger(__name__).warning(msg) - - # Legacy: drop AgentNumber column - if "agent_number" in data.columns: - data = data.drop(["agent_number"], axis=1) - - # Check consistency of objectives data columns - objectives = [col for col in data.columns if col.startswith("objective")] - floats = [col for col in data.columns if col.startswith("obj_data")] - sorting = [col for col in data.columns if col.startswith("obj_sort")] - - if len(objectives) != len(floats) or len(objectives) != len(sorting): - raise ValueError( - "Agent objective, obj_data, and obj_sort columns are inconsistent in " - f"{path}" - ) - - return data - - -def process_agent_parameters(data: pd.DataFrame) -> list[dict]: - """Processes agent parameters DataFrame into a list of agent dictionaries.""" - result = [] - for _, row in data.iterrows(): - # Get objectives data - objectives = ( - row[[i.startswith("objective") for i in row.index]].dropna().to_list() - ) - sorting = row[[i.startswith("obj_sort") for i in row.index]].dropna().to_list() - floats = row[[i.startswith("obj_data") for i in row.index]].dropna().to_list() - - # Create decision parameters - decision_params = list(zip(objectives, sorting, floats)) - - agent_type = { - "new": "newcapa", - "newcapa": "newcapa", - "retrofit": "retrofit", - "retro": "retrofit", - "agent": "agent", - "default": "agent", - }[getattr(row, "type", "agent").lower()] - - # Create agent data dictionary - data = { - "name": row["name"], - "region": row.region, - "objectives": objectives, - "search_rules": row.search_rule, - "decision": {"name": row.decision_method, "parameters": decision_params}, - "agent_type": agent_type, - "quantity": row.quantity, - "share": row.agent_share, - } - - # Add optional parameters - if hasattr(row, "maturity_threshold"): - data["maturity_threshold"] = row.maturity_threshold - if hasattr(row, "spend_limit"): - data["spend_limit"] = row.spend_limit - - # Add agent data to result - result.append(data) - - return result - - -def read_initial_market( - projections_path: Path, - base_year_import_path: Path | None = None, - 
base_year_export_path: Path | None = None, - currency: str | None = None, -) -> xr.Dataset: - """Reads and processes initial market data. - - Args: - projections_path: path to the projections file - base_year_import_path: path to the base year import file (optional) - base_year_export_path: path to the base year export file (optional) - currency: currency string (e.g. "USD") - - Returns: - xr.Dataset: Dataset containing initial market data. - """ - # Read projections - projections_df = read_projections_csv(projections_path) - - # Read base year export (optional) - if base_year_export_path: - export_df = read_csv( - base_year_export_path, - msg=f"Reading base year export from {base_year_export_path}.", - ) - else: - export_df = None - - # Read base year import (optional) - if base_year_import_path: - import_df = read_csv( - base_year_import_path, - msg=f"Reading base year import from {base_year_import_path}.", - ) - else: - import_df = None - - # Assemble into xarray Dataset - result = process_initial_market(projections_df, import_df, export_df, currency) - return result - - -def read_projections_csv(path: Path) -> pd.DataFrame: - """Reads projections data from a CSV file.""" - required_columns = { - "region", - "attribute", - "year", - } - projections_df = read_csv( - path, required_columns=required_columns, msg=f"Reading projections from {path}." - ) - return projections_df - - -def process_initial_market( - projections_df: pd.DataFrame, - import_df: pd.DataFrame | None, - export_df: pd.DataFrame | None, - currency: str | None = None, -) -> xr.Dataset: - """Process market data DataFrames into an xarray Dataset. - - Args: - projections_df: DataFrame containing projections data - import_df: Optional DataFrame containing import data - export_df: Optional DataFrame containing export data - currency: Currency string (e.g. 
"USD") - """ - from muse.commodities import COMMODITIES - from muse.timeslices import broadcast_timeslice, distribute_timeslice - - # Process projections - projections = process_attribute_table(projections_df).commodity_price.astype( - "float64" - ) - - # Process optional trade data - if export_df is not None: - base_year_export = process_attribute_table(export_df).exports.astype("float64") - else: - base_year_export = xr.zeros_like(projections) - - if import_df is not None: - base_year_import = process_attribute_table(import_df).imports.astype("float64") - else: - base_year_import = xr.zeros_like(projections) - - # Distribute data over timeslices - projections = broadcast_timeslice(projections, level=None) - base_year_export = distribute_timeslice(base_year_export, level=None) - base_year_import = distribute_timeslice(base_year_import, level=None) - - # Assemble into xarray - result = xr.Dataset( - { - "prices": projections, - "exports": base_year_export, - "imports": base_year_import, - "static_trade": base_year_import - base_year_export, - } - ) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - - # Add units_prices coordinate - # Only added if the currency is specified and commodity units are defined - if currency and "unit" in COMMODITIES.data_vars: - units_prices = [ - f"{currency}/{COMMODITIES.sel(commodity=c).unit.item()}" - for c in result.commodity.values - ] - result = result.assign_coords(units_prices=("commodity", units_prices)) - - return result - - -def read_attribute_table(path: Path) -> xr.Dataset: - """Reads and processes attribute table data from a CSV file.""" - df = read_attribute_table_csv(path) - return process_attribute_table(df) - - -def read_attribute_table_csv(path: Path) -> pd.DataFrame: - """Read a standard MUSE csv file for price projections into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "attribute", "year"], - msg=f"Reading attribute table from {path}.", - ) - return table - - -def process_attribute_table(data: pd.DataFrame) -> xr.Dataset: - """Process attribute table DataFrame into an xarray Dataset.""" - # Extract commodity columns - commodities = [ - col for col in data.columns if col not in ["region", "year", "attribute"] - ] - - # Convert commodity columns to long format (i.e. single "commodity" column) - data = data.melt( - id_vars=["region", "year", "attribute"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Pivot data over attributes - data = data.pivot( - index=["region", "year", "commodity"], - columns="attribute", - values="value", - ) - - # Create DataSet - result = create_xarray_dataset(data) - return result - - -def read_presets(presets_paths: Path) -> xr.Dataset: - """Reads and processes preset data from multiple CSV files. - - Accepts a path pattern for presets files, e.g. `Path("path/to/*Consumption.csv")`. - The file name of each file must contain a year (e.g. "2020Consumption.csv"). 
- """ - from glob import glob - from re import match - - # Find all files matching the path pattern - allfiles = [Path(p) for p in glob(str(presets_paths))] - if len(allfiles) == 0: - raise OSError(f"No files found with paths {presets_paths}") - - # Read all files - datas: dict[int, pd.DataFrame] = {} - for path in allfiles: - # Extract year from filename - reyear = match(r"\S*.(\d{4})\S*\.csv", path.name) - if reyear is None: - raise OSError(f"Unexpected filename {path.name}") - year = int(reyear.group(1)) - if year in datas: - raise OSError(f"Year f{year} was found twice") - - # Read data - data = read_presets_csv(path) - data["year"] = year - datas[year] = data - - # Process data - datas = process_presets(datas) - return datas - - -def read_presets_csv(path: Path) -> pd.DataFrame: - data = read_csv( - path, - required_columns=["region", "timeslice"], - msg=f"Reading presets from {path}.", - ) - - # Legacy: drop technology column and sum data (PR #448) - if "technology" in data.columns: - getLogger(__name__).warning( - f"The technology (or ProcessName) column in file {path} is " - "deprecated. Data has been summed across technologies, and this column " - "has been dropped." - ) - data = ( - data.drop(columns=["technology"]) - .groupby(["region", "timeslice"]) - .sum() - .reset_index() - ) - - return data - - -def process_presets(datas: dict[int, pd.DataFrame]) -> xr.Dataset: - """Processes preset DataFrames into an xarray Dataset.""" - from muse.commodities import COMMODITIES - from muse.timeslices import TIMESLICE - - # Combine into a single DataFrame - data = pd.concat(datas.values()) - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. 
single "commodity" column) - data = data.melt( - id_vars=["region", "year", "timeslice"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Create multiindex for region, year, timeslice and commodity - data = create_multiindex( - data, - index_columns=["region", "year", "timeslice", "commodity"], - index_names=["region", "year", "timeslice", "commodity"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Assign timeslices - result = result.assign_coords(timeslice=TIMESLICE.timeslice) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_trade_technodata(path: Path) -> xr.Dataset: - """Reads and processes trade technodata from a CSV file.""" - df = read_trade_technodata_csv(path) - return process_trade_technodata(df) - - -def read_trade_technodata_csv(path: Path) -> pd.DataFrame: - required_columns = {"technology", "region", "parameter"} - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading trade technodata from {path}.", - ) - - -def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select region columns - # TODO: this is a bit unsafe as user could supply other columns - regions = [ - col for col in data.columns if col not in ["technology", "region", "parameter"] - ] - - # Melt data over regions - data = data.melt( - id_vars=["technology", "region", "parameter"], - value_vars=regions, - var_name="dst_region", - value_name="value", - ) - - # Pivot data over parameters - data = data.pivot( - index=["technology", "region", "dst_region"], - columns="parameter", - values="value", - ) - - # Create DataSet - return create_xarray_dataset(data) - - -def read_existing_trade(path: Path) -> xr.DataArray: - """Reads and processes existing trade data from a CSV file.""" - df = read_existing_trade_csv(path) - return process_existing_trade(df) - - -def read_existing_trade_csv(path: Path) -> pd.DataFrame: - required_columns = { - "region", - "technology", - "year", - } - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading existing trade from {path}.", - ) - - -def process_existing_trade(data: pd.DataFrame) -> xr.DataArray: - # Select region columns - # TODO: this is a bit unsafe as user could supply other columns - regions = [ - col for col in data.columns if col not in ["technology", "region", "year"] - ] - - # Melt data over regions - data = data.melt( - id_vars=["technology", "region", "year"], - value_vars=regions, - var_name="dst_region", - value_name="value", - ) - - # Create multiindex for region, dst_region, technology and year - data = create_multiindex( - data, - index_columns=["region", "dst_region", "technology", "year"], - index_names=["region", "dst_region", "technology", "year"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Create assets from technologies - result = create_assets(result) - return result - - -def read_timeslice_shares(path: Path) -> xr.DataArray: - """Reads and processes timeslice shares data from a CSV file.""" - df = read_timeslice_shares_csv(path) - return process_timeslice_shares(df) - - -def read_timeslice_shares_csv(path: Path) -> pd.DataFrame: - """Reads sliceshare information into a DataFrame.""" - data = read_csv( - path, - required_columns=["region", "timeslice"], - 
msg=f"Reading timeslice shares from {path}.", - ) - - return data - - -def process_timeslice_shares(data: pd.DataFrame) -> xr.DataArray: - """Processes timeslice shares DataFrame into an xarray DataArray.""" - from muse.commodities import COMMODITIES - from muse.timeslices import TIMESLICE - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. single "commodity" column) - data = data.melt( - id_vars=["region", "timeslice"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Create multiindex for region and timeslice - data = create_multiindex( - data, - index_columns=["region", "timeslice", "commodity"], - index_names=["region", "timeslice", "commodity"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Assign timeslices - result = result.assign_coords(timeslice=TIMESLICE.timeslice) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_macro_drivers(path: Path) -> pd.DataFrame: - """Reads and processes macro drivers data from a CSV file.""" - df = read_macro_drivers_csv(path) - return process_macro_drivers(df) - - -def read_macro_drivers_csv(path: Path) -> pd.DataFrame: - """Reads a standard MUSE csv file for macro drivers into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "variable"], - msg=f"Reading macro drivers from {path}.", - ) - - # Validate required variables - required_variables = ["Population", "GDP|PPP"] - missing_variables = [ - var for var in required_variables if var not in table.variable.unique() - ] - if missing_variables: - raise ValueError(f"Missing required variables in {path}: {missing_variables}") - - return table - - -def process_macro_drivers(data: pd.DataFrame) -> xr.Dataset: - """Processes macro drivers DataFrame into an xarray Dataset.""" - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select year columns - year_columns = [col for col in data.columns if col.isdigit()] - - # Convert year columns to long format (i.e. single "year" column) - data = data.melt( - id_vars=["variable", "region"], - value_vars=year_columns, - var_name="year", - value_name="value", - ) - - # Pivot data to create Population and GDP|PPP columns - data = data.pivot( - index=["region", "year"], - columns="variable", - values="value", - ) - - # Legacy: rename Population to population and GDP|PPP to gdp - if "Population" in data.columns: - data = data.rename(columns={"Population": "population"}) - if "GDP|PPP" in data.columns: - data = data.rename(columns={"GDP|PPP": "gdp"}) - - # Create DataSet - result = create_xarray_dataset(data) - return result - - -def read_regression_parameters(path: Path) -> xr.Dataset: - """Reads and processes regression parameters from a CSV file.""" - df = read_regression_parameters_csv(path) - return process_regression_parameters(df) - - -def read_regression_parameters_csv(path: Path) -> pd.DataFrame: - """Reads the regression parameters from a MUSE csv file into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "function_type", "coeff"], - msg=f"Reading regression parameters from {path}.", - ) - - # Legacy: warn about "sector" column - if "sector" in table.columns: - getLogger(__name__).warning( - f"The sector column (in file {path}) is deprecated. Please remove." 
- ) - - return table - - -def process_regression_parameters(data: pd.DataFrame) -> xr.Dataset: - """Processes regression parameters DataFrame into an xarray Dataset.""" - from muse.commodities import COMMODITIES - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Melt to long format - melted = data.melt( - id_vars=["sector", "region", "function_type", "coeff"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Extract sector -> function_type mapping - sector_to_ftype = melted.drop_duplicates(["sector", "function_type"])[ - ["sector", "function_type"] - ].set_index("sector")["function_type"] - - # Pivot to create coefficient variables - pivoted = melted.pivot_table( - index=["sector", "region", "commodity"], columns="coeff", values="value" - ) - - # Create dataset and add function_type - result = create_xarray_dataset(pivoted) - result["function_type"] = xr.DataArray( - sector_to_ftype[result.sector.values].astype(object), - dims=["sector"], - name="function_type", - ) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def check_utilization_and_minimum_service_factors(data: xr.Dataset) -> None: - """Check utilization and minimum service factors in an xarray dataset. - - Args: - data: xarray Dataset containing utilization_factor and minimum_service_factor - """ - if "utilization_factor" not in data.data_vars: - raise ValueError( - "A technology needs to have a utilization factor defined for every " - "timeslice." - ) - - # Check UF not all zero (sum across timeslice dimension if it exists) - if "timeslice" in data.dims: - utilization_sum = data.utilization_factor.sum(dim="timeslice") - else: - utilization_sum = data.utilization_factor - - if (utilization_sum == 0).any(): - raise ValueError( - "A technology can not have a utilization factor of 0 for every timeslice." - ) - - # Check UF in range - utilization = data.utilization_factor - if not ((utilization >= 0) & (utilization <= 1)).all(): - raise ValueError( - "Utilization factor values must all be between 0 and 1 inclusive." - ) - - # Check MSF in range - min_service_factor = data.minimum_service_factor - if not ((min_service_factor >= 0) & (min_service_factor <= 1)).all(): - raise ValueError( - "Minimum service factor values must all be between 0 and 1 inclusive." - ) - - # Check UF not below MSF - if (data.utilization_factor < data.minimum_service_factor).any(): - raise ValueError( - "Utilization factors must all be greater than or equal " - "to their corresponding minimum service factors." - ) diff --git a/src/muse/readers/csv/__init__.py b/src/muse/readers/csv/__init__.py new file mode 100644 index 000000000..d4a46985c --- /dev/null +++ b/src/muse/readers/csv/__init__.py @@ -0,0 +1,67 @@ +"""Ensemble of functions to read MUSE data. + +In general, there are three functions per input file: +`read_x`: This is the overall function that is called to read the data. It takes a + `Path` as input, and returns the relevant data structure (usually an xarray). The + process is generally broken down into two functions that are called by `read_x`: + +`read_x_csv`: This takes a path to a csv file as input and returns a pandas dataframe. + There are some consistency checks, such as checking data types and columns. There + is also some minor processing at this stage, such as standardising column names, + but no structural changes to the data. 
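For illustration only (not part of the patch): what "standardising column names ... but no structural changes" means in practice, using the standardize_columns helper from the new helpers.py on a made-up header row. The import path assumes the package layout introduced by this patch.

```python
import pandas as pd

from muse.readers.csv.helpers import standardize_columns

raw = pd.DataFrame(
    {"ProcessName": ["gasCCGT"], "RegionName": ["R1"], "Time": [2020], "cap_par": [23.8]}
)
tidy = standardize_columns(raw)
# Columns become ["technology", "region", "year", "cap_par"]; the rows are untouched.
```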
The general rule is that anything returned + by this function should still be valid as an input file if saved to csv. +`process_x`: This is where more major processing and reformatting of the data is done. + It takes the dataframe from `read_x_csv` and returns the final data structure + (usually an xarray). There are also some more checks (e.g. checking for nan + values). + +Most of the processing is shared by a few helper functions: +- read_csv: reads a csv file and returns a dataframe +- standardize_dataframe: standardizes the dataframe to a common format +- create_multiindex: creates a multiindex from a dataframe +- create_xarray_dataset: creates an xarray dataset from a dataframe + +A few other helpers perform common operations on xarrays: +- create_assets: creates assets from technologies +- check_commodities: checks commodities and fills missing values + +""" + +from .agents import read_agent_parameters +from .assets import read_initial_capacity +from .commodities import read_global_commodities +from .general import read_attribute_table +from .helpers import read_csv +from .market import read_initial_market +from .presets import read_presets +from .regression import ( + read_macro_drivers, + read_regression_parameters, + read_timeslice_shares, +) +from .technologies import ( + read_io_technodata, + read_technodata_timeslices, + read_technodictionary, + read_technologies, +) +from .trade import read_existing_trade, read_trade_technodata + +__all__ = [ + "read_agent_parameters", + "read_attribute_table", + "read_csv", + "read_existing_trade", + "read_global_commodities", + "read_initial_capacity", + "read_initial_market", + "read_io_technodata", + "read_macro_drivers", + "read_presets", + "read_regression_parameters", + "read_technodata_timeslices", + "read_technodictionary", + "read_technologies", + "read_timeslice_shares", + "read_trade_technodata", +] diff --git a/src/muse/readers/csv/agents.py b/src/muse/readers/csv/agents.py new file mode 100644 index 000000000..947e679de --- /dev/null +++ b/src/muse/readers/csv/agents.py @@ -0,0 +1,104 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd + +from .helpers import read_csv + + +def read_agent_parameters(path: Path) -> pd.DataFrame: + """Reads and processes agent parameters from a CSV file.""" + df = read_agent_parameters_csv(path) + return process_agent_parameters(df) + + +def read_agent_parameters_csv(path: Path) -> pd.DataFrame: + """Reads standard MUSE agent-declaration csv-files into a DataFrame.""" + required_columns = { + "search_rule", + "quantity", + "region", + "type", + "name", + "agent_share", + "decision_method", + } + data = read_csv( + path, + required_columns=required_columns, + msg=f"Reading agent parameters from {path}.", + ) + + # Check for deprecated retrofit agents + if "type" in data.columns: + retrofit_agents = data[data.type.str.lower().isin(["retrofit", "retro"])] + if not retrofit_agents.empty: + msg = ( + "Retrofit agents will be deprecated in a future release. " + "Please modify your model to use only agents of the 'New' type." 
+ ) + getLogger(__name__).warning(msg) + + # Legacy: drop AgentNumber column + if "agent_number" in data.columns: + data = data.drop(["agent_number"], axis=1) + + # Check consistency of objectives data columns + objectives = [col for col in data.columns if col.startswith("objective")] + floats = [col for col in data.columns if col.startswith("obj_data")] + sorting = [col for col in data.columns if col.startswith("obj_sort")] + + if len(objectives) != len(floats) or len(objectives) != len(sorting): + raise ValueError( + "Agent objective, obj_data, and obj_sort columns are inconsistent in " + f"{path}" + ) + + return data + + +def process_agent_parameters(data: pd.DataFrame) -> list[dict]: + """Processes agent parameters DataFrame into a list of agent dictionaries.""" + result = [] + for _, row in data.iterrows(): + # Get objectives data + objectives = ( + row[[i.startswith("objective") for i in row.index]].dropna().to_list() + ) + sorting = row[[i.startswith("obj_sort") for i in row.index]].dropna().to_list() + floats = row[[i.startswith("obj_data") for i in row.index]].dropna().to_list() + + # Create decision parameters + decision_params = list(zip(objectives, sorting, floats)) + + agent_type = { + "new": "newcapa", + "newcapa": "newcapa", + "retrofit": "retrofit", + "retro": "retrofit", + "agent": "agent", + "default": "agent", + }[getattr(row, "type", "agent").lower()] + + # Create agent data dictionary + data = { + "name": row["name"], + "region": row.region, + "objectives": objectives, + "search_rules": row.search_rule, + "decision": {"name": row.decision_method, "parameters": decision_params}, + "agent_type": agent_type, + "quantity": row.quantity, + "share": row.agent_share, + } + + # Add optional parameters + if hasattr(row, "maturity_threshold"): + data["maturity_threshold"] = row.maturity_threshold + if hasattr(row, "spend_limit"): + data["spend_limit"] = row.spend_limit + + # Add agent data to result + result.append(data) + + return result diff --git a/src/muse/readers/csv/assets.py b/src/muse/readers/csv/assets.py new file mode 100644 index 000000000..e5abeb1e9 --- /dev/null +++ b/src/muse/readers/csv/assets.py @@ -0,0 +1,58 @@ +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import create_assets, create_multiindex, create_xarray_dataset, read_csv + + +def read_initial_capacity(path: Path) -> xr.DataArray: + """Reads and processes initial capacity data from a CSV file.""" + df = read_initial_capacity_csv(path) + return process_initial_capacity(df) + + +def read_initial_capacity_csv(path: Path) -> pd.DataFrame: + """Reads and formats data about initial capacity into a DataFrame.""" + required_columns = { + "region", + "technology", + } + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading initial capacity from {path}.", + ) + + +def process_initial_capacity(data: pd.DataFrame) -> xr.DataArray: + """Processes initial capacity DataFrame into an xarray DataArray.""" + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select year columns + year_columns = [col for col in data.columns if col.isdigit()] + + # Convert year columns to long format (i.e. 
single "year" column) + data = data.melt( + id_vars=["technology", "region"], + value_vars=year_columns, + var_name="year", + value_name="value", + ) + + # Create multiindex for region, technology, and year + data = create_multiindex( + data, + index_columns=["technology", "region", "year"], + index_names=["technology", "region", "year"], + drop_columns=True, + ) + + # Create Dataarray + result = create_xarray_dataset(data).value.astype(float) + + # Create assets + result = create_assets(result) + return result diff --git a/src/muse/readers/csv/commodities.py b/src/muse/readers/csv/commodities.py new file mode 100644 index 000000000..8b3c25df0 --- /dev/null +++ b/src/muse/readers/csv/commodities.py @@ -0,0 +1,57 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import camel_to_snake, create_xarray_dataset, standardize_dataframe + + +def read_global_commodities(path: Path) -> xr.Dataset: + """Reads and processes global commodities data from a CSV file.""" + df = read_global_commodities_csv(path) + return process_global_commodities(df) + + +def read_global_commodities_csv(path: Path) -> pd.DataFrame: + """Reads commodities information from input into a DataFrame.""" + # Due to legacy reasons, users can supply both Commodity and CommodityName columns + # In this case, we need to remove the Commodity column to avoid conflicts + # This is fine because Commodity just contains a long description that isn't needed + getLogger(__name__).info(f"Reading global commodities from {path}.") + df = pd.read_csv(path) + df = df.rename(columns=camel_to_snake) + if "commodity" in df.columns and "commodity_name" in df.columns: + df = df.drop(columns=["commodity"]) + + required_columns = { + "commodity", + "commodity_type", + } + data = standardize_dataframe( + df, + required_columns=required_columns, + ) + + # Raise warning if units are not defined + if "unit" not in data.columns: + msg = ( + "No units defined for commodities. Please define units for all commodities " + "in the global commodities file." + ) + getLogger(__name__).warning(msg) + + return data + + +def process_global_commodities(data: pd.DataFrame) -> xr.Dataset: + """Processes global commodities DataFrame into an xarray Dataset.""" + # Drop description column if present. It's useful to include in the file, but we + # don't need it for the simulation. 
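Illustrative sketch, not part of the patch: the kind of table read_global_commodities consumes and the per-commodity Dataset that process_global_commodities (being defined here) produces. Commodity names, types and units are made up.

```python
import pandas as pd
import xarray as xr

df = pd.DataFrame(
    {
        "commodity": ["electricity", "gas"],
        "commodity_type": ["energy", "energy"],
        "unit": ["PJ", "PJ"],
        "description": ["grid electricity", "natural gas"],
    }
)
# Same reshaping as the function body: drop the free-text description and
# index the remaining variables by commodity name.
df = df.drop(columns=["description"])
df.index = [c for c in df.commodity]
df = df.drop("commodity", axis=1)
df.index.name = "commodity"
ds = xr.Dataset.from_dataframe(df)
# ds has a "commodity" dimension with "commodity_type" and "unit" data variables.
```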
+ if "description" in data.columns: + data = data.drop(columns=["description"]) + + data.index = [u for u in data.commodity] + data = data.drop("commodity", axis=1) + data.index.name = "commodity" + return create_xarray_dataset(data) diff --git a/src/muse/readers/csv/general.py b/src/muse/readers/csv/general.py new file mode 100644 index 000000000..f1bf99657 --- /dev/null +++ b/src/muse/readers/csv/general.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import create_xarray_dataset, read_csv + + +def read_attribute_table(path: Path) -> xr.Dataset: + """Reads and processes attribute table data from a CSV file.""" + df = read_attribute_table_csv(path) + return process_attribute_table(df) + + +def read_attribute_table_csv(path: Path) -> pd.DataFrame: + """Read a standard MUSE csv file for price projections into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "attribute", "year"], + msg=f"Reading attribute table from {path}.", + ) + return table + + +def process_attribute_table(data: pd.DataFrame) -> xr.Dataset: + """Process attribute table DataFrame into an xarray Dataset.""" + # Extract commodity columns + commodities = [ + col for col in data.columns if col not in ["region", "year", "attribute"] + ] + + # Convert commodity columns to long format (i.e. single "commodity" column) + data = data.melt( + id_vars=["region", "year", "attribute"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Pivot data over attributes + data = data.pivot( + index=["region", "year", "commodity"], + columns="attribute", + values="value", + ) + + # Create DataSet + result = create_xarray_dataset(data) + return result diff --git a/src/muse/readers/csv/helpers.py b/src/muse/readers/csv/helpers.py new file mode 100644 index 000000000..b55a2df19 --- /dev/null +++ b/src/muse/readers/csv/helpers.py @@ -0,0 +1,342 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from muse.utilities import camel_to_snake + +# Global mapping of column names to their standardized versions +# This is for backwards compatibility with old file formats +COLUMN_RENAMES = { + "process_name": "technology", + "process": "technology", + "sector_name": "sector", + "region_name": "region", + "time": "year", + "commodity_name": "commodity", + "comm_type": "commodity_type", + "commodity_price": "prices", + "units_commodity_price": "units_prices", + "enduse": "end_use", + "sn": "timeslice", + "commodity_emission_factor_CO2": "emmission_factor", + "utilisation_factor": "utilization_factor", + "objsort": "obj_sort", + "objsort1": "obj_sort1", + "objsort2": "obj_sort2", + "objsort3": "obj_sort3", + "time_slice": "timeslice", + "price": "prices", +} + +# Columns who's values should be converted from camelCase to snake_case +CAMEL_TO_SNAKE_COLUMNS = [ + "tech_type", + "commodity", + "commodity_type", + "agent_share", + "attribute", + "sector", + "region", + "parameter", +] + +# Global mapping of column names to their expected types +COLUMN_TYPES = { + "year": int, + "region": str, + "technology": str, + "commodity": str, + "sector": str, + "attribute": str, + "variable": str, + "timeslice": int, # For tables that require int timeslice instead of month etc. 
+ "name": str, + "commodity_type": str, + "tech_type": str, + "type": str, + "function_type": str, + "level": str, + "search_rule": str, + "decision_method": str, + "quantity": float, + "share": float, + "coeff": str, + "value": float, + "utilization_factor": float, + "minimum_service_factor": float, + "maturity_threshold": float, + "spend_limit": float, + "prices": float, + "emmission_factor": float, +} + +DEFAULTS = { + "cap_par": 0, + "cap_exp": 1, + "fix_par": 0, + "fix_exp": 1, + "var_par": 0, + "var_exp": 1, + "interest_rate": 0, + "utilization_factor": 1, + "minimum_service_factor": 0, + "search_rule": "all", + "decision_method": "single", +} + + +def standardize_columns(data: pd.DataFrame) -> pd.DataFrame: + """Standardizes column names in a DataFrame. + + This function: + 1. Converts column names to snake_case + 2. Applies the global COLUMN_RENAMES mapping + 3. Preserves any columns not in the mapping + + Args: + data: DataFrame to standardize + + Returns: + DataFrame with standardized column names + """ + # Drop index column if present + if data.columns[0] == "" or data.columns[0].startswith("Unnamed"): + data = data.iloc[:, 1:] + + # Convert columns to snake_case + data = data.rename(columns=camel_to_snake) + + # Then apply global mapping + data = data.rename(columns=COLUMN_RENAMES) + + # Make sure there are no duplicate columns + if len(data.columns) != len(set(data.columns)): + raise ValueError(f"Duplicate columns in {data.columns}") + + return data + + +def create_multiindex( + data: pd.DataFrame, + index_columns: list[str], + index_names: list[str], + drop_columns: bool = True, +) -> pd.DataFrame: + """Creates a MultiIndex from specified columns. + + Args: + data: DataFrame to create index from + index_columns: List of column names to use for index + index_names: List of names for the index levels + drop_columns: Whether to drop the original columns + + Returns: + DataFrame with new MultiIndex + """ + index = pd.MultiIndex.from_arrays( + [data[col] for col in index_columns], names=index_names + ) + result = data.copy() + result.index = index + if drop_columns: + result = result.drop(columns=index_columns) + return result + + +def create_xarray_dataset( + data: pd.DataFrame, + disallow_nan: bool = True, +) -> xr.Dataset: + """Creates an xarray Dataset from a DataFrame with standardized options. + + Args: + data: DataFrame to convert + disallow_nan: Whether to raise an error if NaN values are found + + Returns: + xarray Dataset + """ + result = xr.Dataset.from_dataframe(data) + if disallow_nan: + nan_coords = get_nan_coordinates(result) + if nan_coords: + raise ValueError(f"Missing data for coordinates: {nan_coords}") + + if "year" in result.coords: + result = result.assign_coords(year=result.year.astype(int)) + result = result.sortby("year") + assert len(set(result.year.values)) == result.year.data.size # no duplicates + + return result + + +def get_nan_coordinates(dataset: xr.Dataset) -> list[tuple]: + """Get coordinates of a Dataset where any data variable has NaN values.""" + any_nan = sum(var.isnull() for var in dataset.data_vars.values()) + if any_nan.any(): + return any_nan.where(any_nan, drop=True).to_dataframe(name="").index.to_list() + return [] + + +def convert_column_types(data: pd.DataFrame) -> pd.DataFrame: + """Converts DataFrame columns to their expected types. 
+ + Args: + data: DataFrame to convert + + Returns: + DataFrame with converted column types + """ + result = data.copy() + for column, expected_type in COLUMN_TYPES.items(): + if column in result.columns: + try: + if expected_type is int: + result[column] = pd.to_numeric(result[column], downcast="integer") + elif expected_type is float: + result[column] = pd.to_numeric(result[column]).astype(float) + elif expected_type is str: + result[column] = result[column].astype(str) + except (ValueError, TypeError) as e: + raise ValueError( + f"Could not convert column '{column}' to {expected_type.__name__}: {e}" # noqa: E501 + ) + return result + + +def standardize_dataframe( + data: pd.DataFrame, + required_columns: list[str] | None = None, + exclude_extra_columns: bool = False, +) -> pd.DataFrame: + """Standardizes a DataFrame to a common format. + + Args: + data: DataFrame to standardize + required_columns: List of column names that must be present (optional) + exclude_extra_columns: If True, exclude any columns not in required_columns list + (optional). This can be important if extra columns can mess up the resulting + xarray object. + + Returns: + DataFrame containing the standardized data + """ + if required_columns is None: + required_columns = [] + + # Standardize column names + data = standardize_columns(data) + + # Convert specified column values from camelCase to snake_case + for col in CAMEL_TO_SNAKE_COLUMNS: + if col in data.columns: + data[col] = data[col].apply(camel_to_snake) + + # Fill missing values with defaults + data = data.fillna(DEFAULTS) + for col, default in DEFAULTS.items(): + if col not in data.columns and col in required_columns: + data[col] = default + + # Check/convert data types + data = convert_column_types(data) + + # Validate required columns if provided + if required_columns: + missing_columns = [col for col in required_columns if col not in data.columns] + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + # Exclude extra columns if requested + if exclude_extra_columns: + data = data[list(required_columns)] + + return data + + +def read_csv( + path: Path, + float_precision: str = "high", + required_columns: list[str] | None = None, + exclude_extra_columns: bool = False, + msg: str | None = None, +) -> pd.DataFrame: + """Reads and standardizes a CSV file into a DataFrame. + + Args: + path: Path to the CSV file + float_precision: Precision to use when reading floats + required_columns: List of column names that must be present (optional) + exclude_extra_columns: If True, exclude any columns not in required_columns list + (optional). This can be important if extra columns can mess up the resulting + xarray object. 
+ msg: Message to log (optional) + + Returns: + DataFrame containing the standardized data + """ + # Log message + if msg: + getLogger(__name__).info(msg) + + # Check if file exists + if not path.is_file(): + raise OSError(f"{path} does not exist.") + + # Check if there's a units row (in which case we need to skip it) + with open(path) as f: + next(f) # Skip header row + first_data_row = f.readline().strip() + skiprows = [1] if first_data_row.startswith("Unit") else None + + # Read the file + data = pd.read_csv( + path, + float_precision=float_precision, + low_memory=False, + skiprows=skiprows, + ) + + # Standardize the DataFrame + return standardize_dataframe( + data, + required_columns=required_columns, + exclude_extra_columns=exclude_extra_columns, + ) + + +def check_commodities( + data: xr.Dataset | xr.DataArray, fill_missing: bool = True, fill_value: float = 0 +) -> xr.Dataset | xr.DataArray: + """Validates and optionally fills missing commodities in data.""" + from muse.commodities import COMMODITIES + + # Make sure there are no commodities in data but not in global commodities + extra_commodities = [ + c for c in data.commodity.values if c not in COMMODITIES.commodity.values + ] + if extra_commodities: + raise ValueError( + "The following commodities were not found in global commodities file: " + f"{extra_commodities}" + ) + + # Add any missing commodities with fill_value + if fill_missing: + data = data.reindex( + commodity=COMMODITIES.commodity.values, fill_value=fill_value + ) + return data + + +def create_assets(data: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: + """Creates assets from technology data.""" + # Rename technology to asset + result = data.drop_vars("technology").rename(technology="asset") + result["technology"] = "asset", data.technology.values + + # Add installed year + result["installed"] = ("asset", [int(result.year.min())] * len(result.technology)) + return result diff --git a/src/muse/readers/csv/market.py b/src/muse/readers/csv/market.py new file mode 100644 index 000000000..035f84713 --- /dev/null +++ b/src/muse/readers/csv/market.py @@ -0,0 +1,126 @@ +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .general import process_attribute_table +from .helpers import check_commodities, read_csv + + +def read_initial_market( + projections_path: Path, + base_year_import_path: Path | None = None, + base_year_export_path: Path | None = None, + currency: str | None = None, +) -> xr.Dataset: + """Reads and processes initial market data. + + Args: + projections_path: path to the projections file + base_year_import_path: path to the base year import file (optional) + base_year_export_path: path to the base year export file (optional) + currency: currency string (e.g. "USD") + + Returns: + xr.Dataset: Dataset containing initial market data. 
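A hedged usage sketch for this reader, mirroring the call made in the test suite later in this series. Paths are assumptions based on the default example model, and the timeslice and commodity globals are assumed to be initialised via read_settings first:

    from pathlib import Path

    from muse.readers.csv.market import read_initial_market
    from muse.readers.toml import read_settings

    read_settings(Path("model/settings.toml"))  # sets up timeslice/commodity globals
    market = read_initial_market(Path("model/Projections.csv"), currency="MUS$2010")
    print(sorted(market.data_vars))  # ["exports", "imports", "prices", "static_trade"]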
+ """ + # Read projections + projections_df = read_projections_csv(projections_path) + + # Read base year export (optional) + if base_year_export_path: + export_df = read_csv( + base_year_export_path, + msg=f"Reading base year export from {base_year_export_path}.", + ) + else: + export_df = None + + # Read base year import (optional) + if base_year_import_path: + import_df = read_csv( + base_year_import_path, + msg=f"Reading base year import from {base_year_import_path}.", + ) + else: + import_df = None + + # Assemble into xarray Dataset + result = process_initial_market(projections_df, import_df, export_df, currency) + return result + + +def read_projections_csv(path: Path) -> pd.DataFrame: + """Reads projections data from a CSV file.""" + required_columns = { + "region", + "attribute", + "year", + } + projections_df = read_csv( + path, required_columns=required_columns, msg=f"Reading projections from {path}." + ) + return projections_df + + +def process_initial_market( + projections_df: pd.DataFrame, + import_df: pd.DataFrame | None, + export_df: pd.DataFrame | None, + currency: str | None = None, +) -> xr.Dataset: + """Process market data DataFrames into an xarray Dataset. + + Args: + projections_df: DataFrame containing projections data + import_df: Optional DataFrame containing import data + export_df: Optional DataFrame containing export data + currency: Currency string (e.g. "USD") + """ + from muse.commodities import COMMODITIES + from muse.timeslices import broadcast_timeslice, distribute_timeslice + + # Process projections + projections = process_attribute_table(projections_df).commodity_price.astype( + "float64" + ) + + # Process optional trade data + if export_df is not None: + base_year_export = process_attribute_table(export_df).exports.astype("float64") + else: + base_year_export = xr.zeros_like(projections) + + if import_df is not None: + base_year_import = process_attribute_table(import_df).imports.astype("float64") + else: + base_year_import = xr.zeros_like(projections) + + # Distribute data over timeslices + projections = broadcast_timeslice(projections, level=None) + base_year_export = distribute_timeslice(base_year_export, level=None) + base_year_import = distribute_timeslice(base_year_import, level=None) + + # Assemble into xarray + result = xr.Dataset( + { + "prices": projections, + "exports": base_year_export, + "imports": base_year_import, + "static_trade": base_year_import - base_year_export, + } + ) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + + # Add units_prices coordinate + # Only added if the currency is specified and commodity units are defined + if currency and "unit" in COMMODITIES.data_vars: + units_prices = [ + f"{currency}/{COMMODITIES.sel(commodity=c).unit.item()}" + for c in result.commodity.values + ] + result = result.assign_coords(units_prices=("commodity", units_prices)) + + return result diff --git a/src/muse/readers/csv/presets.py b/src/muse/readers/csv/presets.py new file mode 100644 index 000000000..a9efe7043 --- /dev/null +++ b/src/muse/readers/csv/presets.py @@ -0,0 +1,109 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_presets(presets_paths: Path) -> xr.Dataset: + """Reads and processes preset data from multiple CSV files. + + Accepts a path pattern for presets files, e.g. `Path("path/to/*Consumption.csv")`. 
+ The file name of each file must contain a year (e.g. "2020Consumption.csv"). + """ + from glob import glob + from re import match + + # Find all files matching the path pattern + allfiles = [Path(p) for p in glob(str(presets_paths))] + if len(allfiles) == 0: + raise OSError(f"No files found with paths {presets_paths}") + + # Read all files + datas: dict[int, pd.DataFrame] = {} + for path in allfiles: + # Extract year from filename + reyear = match(r"\S*.(\d{4})\S*\.csv", path.name) + if reyear is None: + raise OSError(f"Unexpected filename {path.name}") + year = int(reyear.group(1)) + if year in datas: + raise OSError(f"Year f{year} was found twice") + + # Read data + data = read_presets_csv(path) + data["year"] = year + datas[year] = data + + # Process data + datas = process_presets(datas) + return datas + + +def read_presets_csv(path: Path) -> pd.DataFrame: + data = read_csv( + path, + required_columns=["region", "timeslice"], + msg=f"Reading presets from {path}.", + ) + + # Legacy: drop technology column and sum data (PR #448) + if "technology" in data.columns: + getLogger(__name__).warning( + f"The technology (or ProcessName) column in file {path} is " + "deprecated. Data has been summed across technologies, and this column " + "has been dropped." + ) + data = ( + data.drop(columns=["technology"]) + .groupby(["region", "timeslice"]) + .sum() + .reset_index() + ) + + return data + + +def process_presets(datas: dict[int, pd.DataFrame]) -> xr.Dataset: + """Processes preset DataFrames into an xarray Dataset.""" + from muse.commodities import COMMODITIES + from muse.timeslices import TIMESLICE + + # Combine into a single DataFrame + data = pd.concat(datas.values()) + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["region", "year", "timeslice"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Create multiindex for region, year, timeslice and commodity + data = create_multiindex( + data, + index_columns=["region", "year", "timeslice", "commodity"], + index_names=["region", "year", "timeslice", "commodity"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Assign timeslices + result = result.assign_coords(timeslice=TIMESLICE.timeslice) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result diff --git a/src/muse/readers/csv/regression.py b/src/muse/readers/csv/regression.py new file mode 100644 index 000000000..638d1617b --- /dev/null +++ b/src/muse/readers/csv/regression.py @@ -0,0 +1,185 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_timeslice_shares(path: Path) -> xr.DataArray: + """Reads and processes timeslice shares data from a CSV file.""" + df = read_timeslice_shares_csv(path) + return process_timeslice_shares(df) + + +def read_timeslice_shares_csv(path: Path) -> pd.DataFrame: + """Reads sliceshare information into a DataFrame.""" + data = read_csv( + path, + required_columns=["region", "timeslice"], + msg=f"Reading timeslice shares from {path}.", + ) + + return data + + +def process_timeslice_shares(data: pd.DataFrame) -> xr.DataArray: + """Processes timeslice shares DataFrame into an xarray DataArray.""" + from muse.commodities import COMMODITIES + from muse.timeslices import TIMESLICE + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["region", "timeslice"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Create multiindex for region and timeslice + data = create_multiindex( + data, + index_columns=["region", "timeslice", "commodity"], + index_names=["region", "timeslice", "commodity"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Assign timeslices + result = result.assign_coords(timeslice=TIMESLICE.timeslice) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result + + +def read_macro_drivers(path: Path) -> pd.DataFrame: + """Reads and processes macro drivers data from a CSV file.""" + df = read_macro_drivers_csv(path) + return process_macro_drivers(df) + + +def read_macro_drivers_csv(path: Path) -> pd.DataFrame: + """Reads a standard MUSE csv file for macro drivers into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "variable"], + msg=f"Reading macro drivers from {path}.", + ) + + # Validate required variables + required_variables = ["Population", "GDP|PPP"] + missing_variables = [ + var for var in required_variables if var not in table.variable.unique() + ] + if missing_variables: + raise ValueError(f"Missing required variables in {path}: {missing_variables}") + + return table + + +def process_macro_drivers(data: pd.DataFrame) -> xr.Dataset: + """Processes macro drivers DataFrame into an xarray Dataset.""" + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select year columns + year_columns = [col for col in data.columns if col.isdigit()] + + # Convert year columns to long format (i.e. single "year" column) + data = data.melt( + id_vars=["variable", "region"], + value_vars=year_columns, + var_name="year", + value_name="value", + ) + + # Pivot data to create Population and GDP|PPP columns + data = data.pivot( + index=["region", "year"], + columns="variable", + values="value", + ) + + # Legacy: rename Population to population and GDP|PPP to gdp + if "Population" in data.columns: + data = data.rename(columns={"Population": "population"}) + if "GDP|PPP" in data.columns: + data = data.rename(columns={"GDP|PPP": "gdp"}) + + # Create DataSet + result = create_xarray_dataset(data) + return result + + +def read_regression_parameters(path: Path) -> xr.Dataset: + """Reads and processes regression parameters from a CSV file.""" + df = read_regression_parameters_csv(path) + return process_regression_parameters(df) + + +def read_regression_parameters_csv(path: Path) -> pd.DataFrame: + """Reads the regression parameters from a MUSE csv file into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "function_type", "coeff"], + msg=f"Reading regression parameters from {path}.", + ) + + # Legacy: warn about "sector" column + if "sector" in table.columns: + getLogger(__name__).warning( + f"The sector column (in file {path}) is deprecated. Please remove." 
+ ) + + return table + + +def process_regression_parameters(data: pd.DataFrame) -> xr.Dataset: + """Processes regression parameters DataFrame into an xarray Dataset.""" + from muse.commodities import COMMODITIES + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Melt to long format + melted = data.melt( + id_vars=["sector", "region", "function_type", "coeff"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Extract sector -> function_type mapping + sector_to_ftype = melted.drop_duplicates(["sector", "function_type"])[ + ["sector", "function_type"] + ].set_index("sector")["function_type"] + + # Pivot to create coefficient variables + pivoted = melted.pivot_table( + index=["sector", "region", "commodity"], columns="coeff", values="value" + ) + + # Create dataset and add function_type + result = create_xarray_dataset(pivoted) + result["function_type"] = xr.DataArray( + sector_to_ftype[result.sector.values].astype(object), + dims=["sector"], + name="function_type", + ) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result diff --git a/src/muse/readers/csv/technologies.py b/src/muse/readers/csv/technologies.py new file mode 100644 index 000000000..5e4d71663 --- /dev/null +++ b/src/muse/readers/csv/technologies.py @@ -0,0 +1,368 @@ +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_technodictionary(path: Path) -> xr.Dataset: + """Reads and processes technodictionary data from a CSV file.""" + df = read_technodictionary_csv(path) + return process_technodictionary(df) + + +def read_technodictionary_csv(path: Path) -> pd.DataFrame: + """Reads and formats technodata into a DataFrame.""" + required_columns = { + "cap_exp", + "region", + "var_par", + "fix_exp", + "interest_rate", + "utilization_factor", + "minimum_service_factor", + "year", + "cap_par", + "var_exp", + "technology", + "technical_life", + "fix_par", + } + data = read_csv( + path, + required_columns=required_columns, + msg=f"Reading technodictionary from {path}.", + ) + + # Check for deprecated columns + if "fuel" in data.columns: + msg = ( + f"The 'fuel' column in {path} has been deprecated. " + "This information is now determined from CommIn files. " + "Please remove this column from your Technodata files." + ) + getLogger(__name__).warning(msg) + if "end_use" in data.columns: + msg = ( + f"The 'end_use' column in {path} has been deprecated. " + "This information is now determined from CommOut files. " + "Please remove this column from your Technodata files." + ) + getLogger(__name__).warning(msg) + if "scaling_size" in data.columns: + msg = ( + f"The 'scaling_size' column in {path} has been deprecated. " + "Please remove this column from your Technodata files." 
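For reference, a sketch of how this reader might be called; the path is assumed from the default example model:

    from pathlib import Path

    from muse.readers.csv.technologies import read_technodictionary

    # Dataset indexed by technology, region and year, with cap_par, fix_par, var_par,
    # utilization_factor, etc. as data variables.
    technodata = read_technodictionary(Path("model/power/Technodata.csv"))
    print(technodata.technology.values)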
+ ) + getLogger(__name__).warning(msg) + + return data + + +def process_technodictionary(data: pd.DataFrame) -> xr.Dataset: + """Processes technodictionary DataFrame into an xarray Dataset.""" + # Create multiindex for technology and region + data = create_multiindex( + data, + index_columns=["technology", "region", "year"], + index_names=["technology", "region", "year"], + drop_columns=True, + ) + + # Create dataset + result = create_xarray_dataset(data) + + # Handle tech_type if present + if "type" in result.variables: + result["tech_type"] = result.type.isel(region=0, year=0) + + return result + + +def read_technodata_timeslices(path: Path) -> xr.Dataset: + """Reads and processes technodata timeslices from a CSV file.""" + df = read_technodata_timeslices_csv(path) + return process_technodata_timeslices(df) + + +def read_technodata_timeslices_csv(path: Path) -> pd.DataFrame: + """Reads and formats technodata timeslices into a DataFrame.""" + from muse.timeslices import TIMESLICE + + timeslice_columns = set(TIMESLICE.coords["timeslice"].indexes["timeslice"].names) + required_columns = { + "utilization_factor", + "technology", + "minimum_service_factor", + "region", + "year", + } | timeslice_columns + return read_csv( + path, + required_columns=required_columns, + exclude_extra_columns=True, + msg=f"Reading technodata timeslices from {path}.", + ) + + +def process_technodata_timeslices(data: pd.DataFrame) -> xr.Dataset: + """Processes technodata timeslices DataFrame into an xarray Dataset.""" + from muse.timeslices import TIMESLICE, sort_timeslices + + # Create multiindex for all columns except factor columns + factor_columns = ["utilization_factor", "minimum_service_factor", "obj_sort"] + index_columns = [col for col in data.columns if col not in factor_columns] + data = create_multiindex( + data, + index_columns=index_columns, + index_names=index_columns, + drop_columns=True, + ) + + # Create dataset + result = create_xarray_dataset(data) + + # Stack timeslice levels (month, day, hour) into a single timeslice dimension + timeslice_levels = TIMESLICE.coords["timeslice"].indexes["timeslice"].names + if all(level in result.dims for level in timeslice_levels): + result = result.stack(timeslice=timeslice_levels) + return sort_timeslices(result) + + +def read_io_technodata(path: Path) -> xr.Dataset: + """Reads and processes input/output technodata from a CSV file.""" + df = read_io_technodata_csv(path) + return process_io_technodata(df) + + +def read_io_technodata_csv(path: Path) -> pd.DataFrame: + """Reads process inputs or outputs into a DataFrame.""" + data = read_csv( + path, + required_columns=["technology", "region", "year"], + msg=f"Reading IO technodata from {path}.", + ) + + # Unspecified Level values default to "fixed" + if "level" in data.columns: + data["level"] = data["level"].fillna("fixed") + else: + # Particularly relevant to outputs files where the Level column is omitted by + # default, as only "fixed" outputs are allowed. + data["level"] = "fixed" + + return data + + +def process_io_technodata(data: pd.DataFrame) -> xr.Dataset: + """Processes IO technodata DataFrame into an xarray Dataset.""" + from muse.commodities import COMMODITIES + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["technology", "region", "year", "level"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Pivot data to create fixed and flexible columns + data = data.pivot( + index=["technology", "region", "year", "commodity"], + columns="level", + values="value", + ) + + # Create xarray dataset + result = create_xarray_dataset(data) + + # Fill in flexible data + if "flexible" in result.data_vars: + result["flexible"] = result.flexible.fillna(0) + else: + result["flexible"] = xr.zeros_like(result.fixed).rename("flexible") + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result + + +def read_technologies( + technodata_path: Path, + comm_out_path: Path, + comm_in_path: Path, + time_framework: list[int], + interpolation_mode: str = "linear", + technodata_timeslices_path: Path | None = None, +) -> xr.Dataset: + """Reads and processes technology data from multiple CSV files. + + Will also interpolate data to the time framework if provided. + + Args: + technodata_path: path to the technodata file + comm_out_path: path to the comm_out file + comm_in_path: path to the comm_in file + time_framework: list of years to interpolate data to + interpolation_mode: Interpolation mode to use + technodata_timeslices_path: path to the technodata_timeslices file + + Returns: + xr.Dataset: Dataset containing the processed technology data. Any fields + that differ by year will contain a "year" dimension interpolated to the + time framework. Other fields will not have a "year" dimension. + """ + # Read all data + technodata = read_technodictionary(technodata_path) + comm_out = read_io_technodata(comm_out_path) + comm_in = read_io_technodata(comm_in_path) + technodata_timeslices = ( + read_technodata_timeslices(technodata_timeslices_path) + if technodata_timeslices_path + else None + ) + + # Assemble xarray Dataset + return process_technologies( + technodata, + comm_out, + comm_in, + time_framework, + interpolation_mode, + technodata_timeslices, + ) + + +def process_technologies( + technodata: xr.Dataset, + comm_out: xr.Dataset, + comm_in: xr.Dataset, + time_framework: list[int], + interpolation_mode: str = "linear", + technodata_timeslices: xr.Dataset | None = None, +) -> xr.Dataset: + """Processes technology data DataFrames into an xarray Dataset.""" + from muse.commodities import COMMODITIES, CommodityUsage + from muse.timeslices import drop_timeslice + from muse.utilities import interpolate_technodata + + # Process inputs/outputs + ins = comm_in.rename(flexible="flexible_inputs", fixed="fixed_inputs") + outs = comm_out.rename(flexible="flexible_outputs", fixed="fixed_outputs") + + # Legacy: Remove flexible outputs + if not (outs["flexible_outputs"] == 0).all(): + raise ValueError( + "'flexible' outputs are not permitted. 
All outputs must be 'fixed'" + ) + outs = outs.drop_vars("flexible_outputs") + + # Collect all years from the time framework and data files + time_framework = list( + set(time_framework).union( + technodata.year.values.tolist(), + ins.year.values.tolist(), + outs.year.values.tolist(), + technodata_timeslices.year.values.tolist() if technodata_timeslices else [], + ) + ) + + # Interpolate data to match the time framework + technodata = interpolate_technodata(technodata, time_framework, interpolation_mode) + outs = interpolate_technodata(outs, time_framework, interpolation_mode) + ins = interpolate_technodata(ins, time_framework, interpolation_mode) + if technodata_timeslices: + technodata_timeslices = interpolate_technodata( + technodata_timeslices, time_framework, interpolation_mode + ) + + # Merge inputs/outputs with technodata + technodata = technodata.merge(outs).merge(ins) + + # Merge technodata_timeslices if provided. This will prioritise values defined in + # technodata_timeslices, and fallback to the non-timesliced technodata for any + # values that are not defined in technodata_timeslices. + if technodata_timeslices: + technodata["utilization_factor"] = ( + technodata_timeslices.utilization_factor.combine_first( + technodata.utilization_factor + ) + ) + technodata["minimum_service_factor"] = drop_timeslice( + technodata_timeslices.minimum_service_factor.combine_first( + technodata.minimum_service_factor + ) + ) + + # Check commodities + technodata = check_commodities(technodata, fill_missing=False) + + # Add info about commodities + technodata = technodata.merge(COMMODITIES.sel(commodity=technodata.commodity)) + + # Add commodity usage flags + technodata["comm_usage"] = ( + "commodity", + CommodityUsage.from_technologies(technodata).values, + ) + technodata = technodata.drop_vars("commodity_type") + + # Check utilization and minimum service factors + check_utilization_and_minimum_service_factors(technodata) + + return technodata + + +def check_utilization_and_minimum_service_factors(data: xr.Dataset) -> None: + """Check utilization and minimum service factors in an xarray dataset. + + Args: + data: xarray Dataset containing utilization_factor and minimum_service_factor + """ + if "utilization_factor" not in data.data_vars: + raise ValueError( + "A technology needs to have a utilization factor defined for every " + "timeslice." + ) + + # Check UF not all zero (sum across timeslice dimension if it exists) + if "timeslice" in data.dims: + utilization_sum = data.utilization_factor.sum(dim="timeslice") + else: + utilization_sum = data.utilization_factor + + if (utilization_sum == 0).any(): + raise ValueError( + "A technology can not have a utilization factor of 0 for every timeslice." + ) + + # Check UF in range + utilization = data.utilization_factor + if not ((utilization >= 0) & (utilization <= 1)).all(): + raise ValueError( + "Utilization factor values must all be between 0 and 1 inclusive." + ) + + # Check MSF in range + min_service_factor = data.minimum_service_factor + if not ((min_service_factor >= 0) & (min_service_factor <= 1)).all(): + raise ValueError( + "Minimum service factor values must all be between 0 and 1 inclusive." + ) + + # Check UF not below MSF + if (data.utilization_factor < data.minimum_service_factor).any(): + raise ValueError( + "Utilization factors must all be greater than or equal " + "to their corresponding minimum service factors." 
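A minimal, self-contained sketch of the validation above, along the lines of the tests updated later in this series; the dimension sizes and values are illustrative:

    import xarray as xr

    from muse.readers.csv.technologies import check_utilization_and_minimum_service_factors

    ds = xr.Dataset(
        {
            "utilization_factor": ("timeslice", [0.5, 1.0]),
            "minimum_service_factor": ("timeslice", [0.0, 0.5]),
        }
    )
    check_utilization_and_minimum_service_factors(ds)  # passes; raises ValueError on violations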
+ ) diff --git a/src/muse/readers/csv/trade.py b/src/muse/readers/csv/trade.py new file mode 100644 index 000000000..29efa305a --- /dev/null +++ b/src/muse/readers/csv/trade.py @@ -0,0 +1,106 @@ +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + create_assets, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_trade_technodata(path: Path) -> xr.Dataset: + """Reads and processes trade technodata from a CSV file.""" + df = read_trade_technodata_csv(path) + return process_trade_technodata(df) + + +def read_trade_technodata_csv(path: Path) -> pd.DataFrame: + required_columns = {"technology", "region", "parameter"} + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading trade technodata from {path}.", + ) + + +def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select region columns + # TODO: this is a bit unsafe as user could supply other columns + regions = [ + col for col in data.columns if col not in ["technology", "region", "parameter"] + ] + + # Melt data over regions + data = data.melt( + id_vars=["technology", "region", "parameter"], + value_vars=regions, + var_name="dst_region", + value_name="value", + ) + + # Pivot data over parameters + data = data.pivot( + index=["technology", "region", "dst_region"], + columns="parameter", + values="value", + ) + + # Create DataSet + return create_xarray_dataset(data) + + +def read_existing_trade(path: Path) -> xr.DataArray: + """Reads and processes existing trade data from a CSV file.""" + df = read_existing_trade_csv(path) + return process_existing_trade(df) + + +def read_existing_trade_csv(path: Path) -> pd.DataFrame: + required_columns = { + "region", + "technology", + "year", + } + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading existing trade from {path}.", + ) + + +def process_existing_trade(data: pd.DataFrame) -> xr.DataArray: + # Select region columns + # TODO: this is a bit unsafe as user could supply other columns + regions = [ + col for col in data.columns if col not in ["technology", "region", "year"] + ] + + # Melt data over regions + data = data.melt( + id_vars=["technology", "region", "year"], + value_vars=regions, + var_name="dst_region", + value_name="value", + ) + + # Create multiindex for region, dst_region, technology and year + data = create_multiindex( + data, + index_columns=["region", "dst_region", "technology", "year"], + index_names=["region", "dst_region", "technology", "year"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Create assets from technologies + result = create_assets(result) + return result From 6b8fc9a79094667f1b17317bfc0162e064afd195 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 8 Aug 2025 16:15:07 +0100 Subject: [PATCH 2/7] Update import paths --- tests/test_readers.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_readers.py b/tests/test_readers.py index 924dcacff..cae57d94d 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -187,7 +187,9 @@ def test_suffix_path_formatting(suffix, tmp_path): def test_check_utilization_and_minimum_service(): """Test combined validation of utilization and minimum service factors.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + 
check_utilization_and_minimum_service_factors, + ) # Test valid case - create dataset with proper dimensions ds = xr.Dataset( @@ -240,7 +242,9 @@ def test_check_utilization_and_minimum_service(): def test_check_utilization_not_all_zero_fail(): """Test validation fails when all utilization factors are zero.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + check_utilization_and_minimum_service_factors, + ) ds = xr.Dataset( { @@ -260,7 +264,9 @@ def test_check_utilization_not_all_zero_fail(): ) def test_check_utilization_in_range_fail(values): """Test validation fails for utilization factors outside valid range.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + check_utilization_and_minimum_service_factors, + ) ds = xr.Dataset( { @@ -277,7 +283,7 @@ def test_check_utilization_in_range_fail(values): def test_get_nan_coordinates(): """Test get_nan_coordinates for various scenarios.""" - from muse.readers.csv import get_nan_coordinates + from muse.readers.csv.helpers import get_nan_coordinates # Test 1: Explicit NaN values df1 = pd.DataFrame( From cb79c08f6f705e6e9af2e12ebda5f3633ee540fd Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 11 Aug 2025 09:26:00 +0100 Subject: [PATCH 3/7] from __future__ import annotations --- src/muse/readers/csv/helpers.py | 2 ++ src/muse/readers/csv/market.py | 2 ++ src/muse/readers/csv/regression.py | 2 ++ src/muse/readers/csv/technologies.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/src/muse/readers/csv/helpers.py b/src/muse/readers/csv/helpers.py index b55a2df19..c5909572a 100644 --- a/src/muse/readers/csv/helpers.py +++ b/src/muse/readers/csv/helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from logging import getLogger from pathlib import Path diff --git a/src/muse/readers/csv/market.py b/src/muse/readers/csv/market.py index 035f84713..3549669fc 100644 --- a/src/muse/readers/csv/market.py +++ b/src/muse/readers/csv/market.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import pandas as pd diff --git a/src/muse/readers/csv/regression.py b/src/muse/readers/csv/regression.py index 638d1617b..1ab78e75f 100644 --- a/src/muse/readers/csv/regression.py +++ b/src/muse/readers/csv/regression.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from logging import getLogger from pathlib import Path diff --git a/src/muse/readers/csv/technologies.py b/src/muse/readers/csv/technologies.py index 5e4d71663..680f2036f 100644 --- a/src/muse/readers/csv/technologies.py +++ b/src/muse/readers/csv/technologies.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from logging import getLogger from pathlib import Path From 36494d4a05bfff6a75aafddf8030dc2f43b16bef Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 11 Aug 2025 09:53:19 +0100 Subject: [PATCH 4/7] Fix test fixtures --- tests/test_csv_readers.py | 18 ---------- tests/test_read_csv.py | 73 +++++++++++++++++++++++++-------------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/tests/test_csv_readers.py b/tests/test_csv_readers.py index d02c68125..ecc8ae4f4 100644 --- a/tests/test_csv_readers.py +++ b/tests/test_csv_readers.py @@ -127,24 +127,6 @@ def assert_single_coordinate(data, selection, expected): ) -@fixture -def timeslice(): - """Sets up global timeslicing scheme to match the default model.""" - from muse.timeslices import setup_module - - timeslice = """ - [timeslices] - 
all-year.all-week.night = 1460 - all-year.all-week.morning = 1460 - all-year.all-week.afternoon = 1460 - all-year.all-week.early-peak = 1460 - all-year.all-week.late-peak = 1460 - all-year.all-week.evening = 1460 - """ - - setup_module(timeslice) - - @fixture def model_path(tmp_path): """Creates temporary folder containing the default model.""" diff --git a/tests/test_read_csv.py b/tests/test_read_csv.py index a09e10c22..4ed45c8bd 100644 --- a/tests/test_read_csv.py +++ b/tests/test_read_csv.py @@ -1,52 +1,51 @@ -import pytest +from __future__ import annotations + +from pytest import fixture from muse import examples -from muse.readers.csv import ( - read_agent_parameters_csv, - read_existing_trade_csv, - read_global_commodities_csv, - read_initial_capacity_csv, - read_macro_drivers_csv, - read_presets_csv, - read_projections_csv, - read_regression_parameters_csv, - read_technodata_timeslices_csv, - read_technodictionary_csv, - read_timeslice_shares_csv, - read_trade_technodata_csv, -) - - -@pytest.fixture +from muse.readers.toml import read_settings + + +@fixture def model_path(tmp_path): """Creates temporary folder containing the default model.""" examples.copy_model(name="default", path=tmp_path) - return tmp_path / "model" + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path -@pytest.fixture +@fixture def timeslice_model_path(tmp_path): """Creates temporary folder containing the default model.""" examples.copy_model(name="default_timeslice", path=tmp_path) - return tmp_path / "model" + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path -@pytest.fixture +@fixture def trade_model_path(tmp_path): - """Creates temporary folder containing the default model.""" + """Creates temporary folder containing the trade model.""" examples.copy_model(name="trade", path=tmp_path) - return tmp_path / "model" + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path -@pytest.fixture +@fixture def correlation_model_path(tmp_path): """Creates temporary folder containing the correlation model.""" examples.copy_model(name="default_correlation", path=tmp_path) - return tmp_path / "model" + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path def test_read_technodictionary_csv(model_path): """Test reading the technodictionary CSV file.""" + from muse.readers.csv.technologies import read_technodictionary_csv + technodictionary_path = model_path / "power" / "Technodata.csv" technodictionary_df = read_technodictionary_csv(technodictionary_path) assert technodictionary_df is not None @@ -76,6 +75,8 @@ def test_read_technodictionary_csv(model_path): def test_read_technodata_timeslices_csv(timeslice_model_path): """Test reading the technodata timeslices CSV file.""" + from muse.readers.csv.technologies import read_technodata_timeslices_csv + timeslices_path = timeslice_model_path / "power" / "TechnodataTimeslices.csv" timeslices_df = read_technodata_timeslices_csv(timeslices_path) mandatory_columns = { @@ -95,6 +96,8 @@ def test_read_technodata_timeslices_csv(timeslice_model_path): def test_read_initial_capacity_csv(model_path): """Test reading the initial capacity CSV file.""" + from muse.readers.csv.assets import read_initial_capacity_csv + capacity_path = model_path / "power" / "ExistingCapacity.csv" capacity_df = read_initial_capacity_csv(capacity_path) mandatory_columns = { @@ -115,6 +118,8 @@ def test_read_initial_capacity_csv(model_path): 
def test_read_global_commodities_csv(model_path): """Test reading the global commodities CSV file.""" + from muse.readers.csv.commodities import read_global_commodities_csv + commodities_path = model_path / "GlobalCommodities.csv" commodities_df = read_global_commodities_csv(commodities_path) mandatory_columns = { @@ -127,6 +132,8 @@ def test_read_global_commodities_csv(model_path): def test_read_timeslice_shares_csv(correlation_model_path): """Test reading the timeslice shares CSV file.""" + from muse.readers.csv.regression import read_timeslice_shares_csv + shares_path = ( correlation_model_path / "residential_presets" / "TimesliceSharepreset.csv" ) @@ -147,6 +154,8 @@ def test_read_timeslice_shares_csv(correlation_model_path): def test_read_agent_parameters_csv(model_path): """Test reading the agent parameters CSV file.""" + from muse.readers.csv.agents import read_agent_parameters_csv + agents_path = model_path / "Agents.csv" agents_df = read_agent_parameters_csv(agents_path) mandatory_columns = { @@ -168,6 +177,8 @@ def test_read_agent_parameters_csv(model_path): def test_read_macro_drivers_csv(correlation_model_path): """Test reading the macro drivers CSV file.""" + from muse.readers.csv.regression import read_macro_drivers_csv + macro_path = correlation_model_path / "residential_presets" / "Macrodrivers.csv" macro_df = read_macro_drivers_csv(macro_path) mandatory_columns = { @@ -185,6 +196,8 @@ def test_read_macro_drivers_csv(correlation_model_path): def test_read_projections_csv(model_path): """Test reading the projections CSV file.""" + from muse.readers.csv.market import read_projections_csv + projections_path = model_path / "Projections.csv" projections_df = read_projections_csv(projections_path) mandatory_columns = { @@ -202,6 +215,8 @@ def test_read_projections_csv(model_path): def test_read_regression_parameters_csv(correlation_model_path): """Test reading the regression parameters CSV file.""" + from muse.readers.csv.regression import read_regression_parameters_csv + regression_path = ( correlation_model_path / "residential_presets" / "regressionparameters.csv" ) @@ -223,6 +238,8 @@ def test_read_regression_parameters_csv(correlation_model_path): def test_read_presets_csv(model_path): """Test reading the presets CSV files.""" + from muse.readers.csv.presets import read_presets_csv + presets_path = model_path / "residential_presets" / "Residential2020Consumption.csv" presets_df = read_presets_csv(presets_path) @@ -238,6 +255,8 @@ def test_read_presets_csv(model_path): def test_read_existing_trade_csv(trade_model_path): """Test reading the existing trade CSV file.""" + from muse.readers.csv.trade import read_existing_trade_csv + trade_path = trade_model_path / "power" / "ExistingTrade.csv" trade_df = read_existing_trade_csv(trade_path) mandatory_columns = { @@ -251,6 +270,8 @@ def test_read_existing_trade_csv(trade_model_path): def test_read_trade_technodata(trade_model_path): """Test reading the trade technodata CSV file.""" + from muse.readers.csv.trade import read_trade_technodata_csv + trade_path = trade_model_path / "power" / "TradeTechnodata.csv" trade_df = read_trade_technodata_csv(trade_path) mandatory_columns = {"technology", "region", "parameter"} From 695a6c516e5de08142532edd17b6dc6d17ea9e80 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 11 Aug 2025 10:39:18 +0100 Subject: [PATCH 5/7] Share test fixtures --- tests/conftest.py | 47 +++++++++++++++ tests/test_csv_readers.py | 122 ++++++++++++++------------------------ tests/test_read_csv.py | 91 
++++++++++------------------ tests/test_wizard.py | 70 ++++++++++------------ 4 files changed, 153 insertions(+), 177 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3cebe1449..379640a1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,10 @@ from pytest import fixture from xarray import DataArray, Dataset +from muse import examples from muse.__main__ import patched_broadcast_compat_data from muse.agents import Agent +from muse.readers.toml import read_settings @contextmanager @@ -598,3 +600,48 @@ def rng(request): from numpy.random import default_rng return default_rng(getattr(request.config.option, "randomly_seed", None)) + + +@fixture +def default_model_path(tmp_path): + """Creates temporary folder containing the default model.""" + examples.copy_model(name="default", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_timeslice_model_path(tmp_path): + """Creates temporary folder containing the default model.""" + examples.copy_model(name="default_timeslice", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_retro_model_path(tmp_path): + """Creates temporary folder containing the default_retro model.""" + examples.copy_model(name="default_retro", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def trade_model_path(tmp_path): + """Creates temporary folder containing the trade model.""" + examples.copy_model(name="trade", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_correlation_model_path(tmp_path): + """Creates temporary folder containing the correlation model.""" + examples.copy_model(name="default_correlation", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path diff --git a/tests/test_csv_readers.py b/tests/test_csv_readers.py index ecc8ae4f4..6feb9ca8d 100644 --- a/tests/test_csv_readers.py +++ b/tests/test_csv_readers.py @@ -4,9 +4,7 @@ import numpy as np import xarray as xr -from pytest import fixture -from muse import examples from muse.readers.toml import read_settings # Common test data @@ -127,46 +125,10 @@ def assert_single_coordinate(data, selection, expected): ) -@fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def timeslice_model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default_timeslice", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def trade_model_path(tmp_path): - """Creates temporary folder containing the trade model.""" - examples.copy_model(name="trade", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def correlation_model_path(tmp_path): - """Creates temporary folder containing the correlation model.""" - examples.copy_model(name="default_correlation", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -def 
test_read_global_commodities(model_path): +def test_read_global_commodities(default_model_path): from muse.readers.csv import read_global_commodities - path = model_path / "GlobalCommodities.csv" + path = default_model_path / "GlobalCommodities.csv" data = read_global_commodities(path) # Check data against schema @@ -192,10 +154,10 @@ def test_read_global_commodities(model_path): assert_single_coordinate(data, coord, expected) -def test_read_presets(model_path): +def test_read_presets(default_model_path): from muse.readers.csv import read_presets - data = read_presets(str(model_path / "residential_presets" / "*.csv")) + data = read_presets(str(default_model_path / "residential_presets" / "*.csv")) # Check data against schema expected_schema = DataArraySchema( @@ -237,10 +199,12 @@ def test_read_presets(model_path): ) -def test_read_initial_market(model_path): +def test_read_initial_market(default_model_path): from muse.readers.csv import read_initial_market - data = read_initial_market(model_path / "Projections.csv", currency="MUS$2010") + data = read_initial_market( + default_model_path / "Projections.csv", currency="MUS$2010" + ) # Check data against schema expected_schema = DatasetSchema( @@ -300,10 +264,10 @@ def test_read_initial_market(model_path): assert_single_coordinate(data, coord, expected) -def test_read_technodictionary(model_path): +def test_read_technodictionary(default_model_path): from muse.readers.csv import read_technodictionary - data = read_technodictionary(model_path / "power" / "Technodata.csv") + data = read_technodictionary(default_model_path / "power" / "Technodata.csv") # Check data against schema expected_schema = DatasetSchema( @@ -370,11 +334,11 @@ def test_read_technodictionary(model_path): assert_single_coordinate(data, coord, expected) -def test_read_technodata_timeslices(timeslice_model_path): +def test_read_technodata_timeslices(default_timeslice_model_path): from muse.readers.csv import read_technodata_timeslices data = read_technodata_timeslices( - timeslice_model_path / "power" / "TechnodataTimeslices.csv" + default_timeslice_model_path / "power" / "TechnodataTimeslices.csv" ) # Check data against schema @@ -420,10 +384,10 @@ def test_read_technodata_timeslices(timeslice_model_path): assert_single_coordinate(data, coord, expected) -def test_read_io_technodata(model_path): +def test_read_io_technodata(default_model_path): from muse.readers.csv import read_io_technodata - data = read_io_technodata(model_path / "power" / "CommIn.csv") + data = read_io_technodata(default_model_path / "power" / "CommIn.csv") # Check data against schema expected_schema = DatasetSchema( @@ -458,10 +422,10 @@ def test_read_io_technodata(model_path): assert_single_coordinate(data, coord, expected) -def test_read_initial_capacity(model_path): +def test_read_initial_capacity(default_model_path): from muse.readers.csv import read_initial_capacity - data = read_initial_capacity(model_path / "power" / "ExistingCapacity.csv") + data = read_initial_capacity(default_model_path / "power" / "ExistingCapacity.csv") # Check data against schema expected_schema = DataArraySchema( @@ -494,10 +458,10 @@ def test_read_initial_capacity(model_path): assert data.sel(region="r1", asset=0, year=2020).item() == 1 -def test_read_agent_parameters(model_path): +def test_read_agent_parameters(default_model_path): from muse.readers.csv import read_agent_parameters - data = read_agent_parameters(model_path / "Agents.csv") + data = read_agent_parameters(default_model_path / "Agents.csv") assert 
isinstance(data, list) assert len(data) == 1 @@ -598,11 +562,13 @@ def test_read_trade_technodata(trade_model_path): assert_single_coordinate(data, coord, expected) -def test_read_timeslice_shares(correlation_model_path): +def test_read_timeslice_shares(default_correlation_model_path): from muse.readers.csv import read_timeslice_shares data = read_timeslice_shares( - correlation_model_path / "residential_presets" / "TimesliceSharepreset.csv" + default_correlation_model_path + / "residential_presets" + / "TimesliceSharepreset.csv" ) # Check data against schema @@ -640,11 +606,11 @@ def test_read_timeslice_shares(correlation_model_path): assert data.sel(**coord).item() == 0.071 -def test_read_macro_drivers(correlation_model_path): +def test_read_macro_drivers(default_correlation_model_path): from muse.readers.csv import read_macro_drivers data = read_macro_drivers( - correlation_model_path / "residential_presets" / "Macrodrivers.csv" + default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" ) # Check data against schema @@ -679,11 +645,13 @@ def test_read_macro_drivers(correlation_model_path): assert_single_coordinate(data, coord, expected) -def test_read_regression_parameters(correlation_model_path): +def test_read_regression_parameters(default_correlation_model_path): from muse.readers.csv import read_regression_parameters data = read_regression_parameters( - correlation_model_path / "residential_presets" / "regressionparameters.csv" + default_correlation_model_path + / "residential_presets" + / "regressionparameters.csv" ) # Check data against schema @@ -729,14 +697,14 @@ def test_read_regression_parameters(correlation_model_path): assert_single_coordinate(data, coord, expected) -def test_read_technologies(model_path): +def test_read_technologies(default_model_path): from muse.readers.csv import read_technologies # Read technologies data = read_technologies( - technodata_path=model_path / "power" / "Technodata.csv", - comm_out_path=model_path / "power" / "CommOut.csv", - comm_in_path=model_path / "power" / "CommIn.csv", + technodata_path=default_model_path / "power" / "Technodata.csv", + comm_out_path=default_model_path / "power" / "CommOut.csv", + comm_in_path=default_model_path / "power" / "CommIn.csv", time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050], interpolation_mode="linear", ) @@ -785,15 +753,15 @@ def test_read_technologies(model_path): ) -def test_read_technologies__timeslice(timeslice_model_path): +def test_read_technologies__timeslice(default_timeslice_model_path): """Testing the read_technologies function with the timeslice model.""" from muse.readers.csv import read_technologies data = read_technologies( - technodata_path=timeslice_model_path / "power" / "Technodata.csv", - comm_out_path=timeslice_model_path / "power" / "CommOut.csv", - comm_in_path=timeslice_model_path / "power" / "CommIn.csv", - technodata_timeslices_path=timeslice_model_path + technodata_path=default_timeslice_model_path / "power" / "Technodata.csv", + comm_out_path=default_timeslice_model_path / "power" / "CommOut.csv", + comm_in_path=default_timeslice_model_path / "power" / "CommIn.csv", + technodata_timeslices_path=default_timeslice_model_path / "power" / "TechnodataTimeslices.csv", time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050], @@ -850,10 +818,10 @@ def test_read_technologies__timeslice(timeslice_model_path): ) -def test_read_technodata(model_path): +def test_read_technodata(default_model_path): from muse.readers.toml import read_settings, read_technodata - 
settings = read_settings(model_path / "settings.toml") + settings = read_settings(default_model_path / "settings.toml") data = read_technodata( settings, sector_name="power", @@ -963,10 +931,10 @@ def test_read_technodata__trade(trade_model_path): ) -def test_read_presets_sector(model_path): - from muse.readers.toml import read_presets_sector, read_settings +def test_read_presets_sector(default_model_path): + from muse.readers.toml import read_presets_sector - settings = read_settings(model_path / "settings.toml") + settings = read_settings(default_model_path / "settings.toml") data = read_presets_sector(settings, sector_name="residential_presets") # Check data against schema @@ -1011,11 +979,11 @@ def test_read_presets_sector(model_path): assert_single_coordinate(data, coord, expected) -def test_read_presets_sector__correlation(correlation_model_path): +def test_read_presets_sector__correlation(default_correlation_model_path): """Testing the read_presets_sector function with the correlation model.""" - from muse.readers.toml import read_presets_sector, read_settings + from muse.readers.toml import read_presets_sector - settings = read_settings(correlation_model_path / "settings.toml") + settings = read_settings(default_correlation_model_path / "settings.toml") data = read_presets_sector(settings, sector_name="residential_presets") # Check data against schema diff --git a/tests/test_read_csv.py b/tests/test_read_csv.py index 4ed45c8bd..014345705 100644 --- a/tests/test_read_csv.py +++ b/tests/test_read_csv.py @@ -1,52 +1,11 @@ from __future__ import annotations -from pytest import fixture -from muse import examples -from muse.readers.toml import read_settings - - -@fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def timeslice_model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default_timeslice", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def trade_model_path(tmp_path): - """Creates temporary folder containing the trade model.""" - examples.copy_model(name="trade", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def correlation_model_path(tmp_path): - """Creates temporary folder containing the correlation model.""" - examples.copy_model(name="default_correlation", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -def test_read_technodictionary_csv(model_path): +def test_read_technodictionary_csv(default_model_path): """Test reading the technodictionary CSV file.""" from muse.readers.csv.technologies import read_technodictionary_csv - technodictionary_path = model_path / "power" / "Technodata.csv" + technodictionary_path = default_model_path / "power" / "Technodata.csv" technodictionary_df = read_technodictionary_csv(technodictionary_path) assert technodictionary_df is not None mandatory_columns = { @@ -73,11 +32,13 @@ def test_read_technodictionary_csv(model_path): assert set(technodictionary_df.columns) == mandatory_columns | extra_columns -def test_read_technodata_timeslices_csv(timeslice_model_path): +def test_read_technodata_timeslices_csv(default_timeslice_model_path): """Test reading 
the technodata timeslices CSV file.""" from muse.readers.csv.technologies import read_technodata_timeslices_csv - timeslices_path = timeslice_model_path / "power" / "TechnodataTimeslices.csv" + timeslices_path = ( + default_timeslice_model_path / "power" / "TechnodataTimeslices.csv" + ) timeslices_df = read_technodata_timeslices_csv(timeslices_path) mandatory_columns = { "utilization_factor", @@ -94,11 +55,11 @@ def test_read_technodata_timeslices_csv(timeslice_model_path): assert set(timeslices_df.columns) == mandatory_columns | extra_columns -def test_read_initial_capacity_csv(model_path): +def test_read_initial_capacity_csv(default_model_path): """Test reading the initial capacity CSV file.""" from muse.readers.csv.assets import read_initial_capacity_csv - capacity_path = model_path / "power" / "ExistingCapacity.csv" + capacity_path = default_model_path / "power" / "ExistingCapacity.csv" capacity_df = read_initial_capacity_csv(capacity_path) mandatory_columns = { "region", @@ -116,11 +77,11 @@ def test_read_initial_capacity_csv(model_path): assert set(capacity_df.columns) == mandatory_columns | extra_columns -def test_read_global_commodities_csv(model_path): +def test_read_global_commodities_csv(default_model_path): """Test reading the global commodities CSV file.""" from muse.readers.csv.commodities import read_global_commodities_csv - commodities_path = model_path / "GlobalCommodities.csv" + commodities_path = default_model_path / "GlobalCommodities.csv" commodities_df = read_global_commodities_csv(commodities_path) mandatory_columns = { "commodity", @@ -130,12 +91,14 @@ def test_read_global_commodities_csv(model_path): assert set(commodities_df.columns) == mandatory_columns | extra_columns -def test_read_timeslice_shares_csv(correlation_model_path): +def test_read_timeslice_shares_csv(default_correlation_model_path): """Test reading the timeslice shares CSV file.""" from muse.readers.csv.regression import read_timeslice_shares_csv shares_path = ( - correlation_model_path / "residential_presets" / "TimesliceSharepreset.csv" + default_correlation_model_path + / "residential_presets" + / "TimesliceSharepreset.csv" ) shares_df = read_timeslice_shares_csv(shares_path) mandatory_columns = { @@ -152,11 +115,11 @@ def test_read_timeslice_shares_csv(correlation_model_path): assert set(shares_df.columns) == mandatory_columns | extra_columns -def test_read_agent_parameters_csv(model_path): +def test_read_agent_parameters_csv(default_model_path): """Test reading the agent parameters CSV file.""" from muse.readers.csv.agents import read_agent_parameters_csv - agents_path = model_path / "Agents.csv" + agents_path = default_model_path / "Agents.csv" agents_df = read_agent_parameters_csv(agents_path) mandatory_columns = { "search_rule", @@ -175,11 +138,13 @@ def test_read_agent_parameters_csv(model_path): assert set(agents_df.columns) == mandatory_columns | extra_columns -def test_read_macro_drivers_csv(correlation_model_path): +def test_read_macro_drivers_csv(default_correlation_model_path): """Test reading the macro drivers CSV file.""" from muse.readers.csv.regression import read_macro_drivers_csv - macro_path = correlation_model_path / "residential_presets" / "Macrodrivers.csv" + macro_path = ( + default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" + ) macro_df = read_macro_drivers_csv(macro_path) mandatory_columns = { "region", @@ -194,11 +159,11 @@ def test_read_macro_drivers_csv(correlation_model_path): assert "GDP|PPP" in macro_df["variable"].values -def 
test_read_projections_csv(model_path): +def test_read_projections_csv(default_model_path): """Test reading the projections CSV file.""" from muse.readers.csv.market import read_projections_csv - projections_path = model_path / "Projections.csv" + projections_path = default_model_path / "Projections.csv" projections_df = read_projections_csv(projections_path) mandatory_columns = { "year", @@ -213,12 +178,14 @@ def test_read_projections_csv(model_path): assert set(projections_df.columns) == mandatory_columns | extra_columns -def test_read_regression_parameters_csv(correlation_model_path): +def test_read_regression_parameters_csv(default_correlation_model_path): """Test reading the regression parameters CSV file.""" from muse.readers.csv.regression import read_regression_parameters_csv regression_path = ( - correlation_model_path / "residential_presets" / "regressionparameters.csv" + default_correlation_model_path + / "residential_presets" + / "regressionparameters.csv" ) regression_df = read_regression_parameters_csv(regression_path) mandatory_columns = { @@ -236,11 +203,13 @@ def test_read_regression_parameters_csv(correlation_model_path): assert set(regression_df.columns) == mandatory_columns | extra_columns -def test_read_presets_csv(model_path): +def test_read_presets_csv(default_model_path): """Test reading the presets CSV files.""" from muse.readers.csv.presets import read_presets_csv - presets_path = model_path / "residential_presets" / "Residential2020Consumption.csv" + presets_path = ( + default_model_path / "residential_presets" / "Residential2020Consumption.csv" + ) presets_df = read_presets_csv(presets_path) mandatory_columns = { diff --git a/tests/test_wizard.py b/tests/test_wizard.py index da3e29ca3..1dcdfd07a 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,10 +1,8 @@ from pathlib import Path import pandas as pd -import pytest from tomlkit import dumps, parse -from muse import examples from muse.wizard import ( add_agent, add_new_commodity, @@ -17,20 +15,6 @@ ) -@pytest.fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - return tmp_path / "model" - - -@pytest.fixture -def model_path_retro(tmp_path): - """Creates temporary folder containing the default_retro model.""" - examples.copy_model(name="default_retro", path=tmp_path) - return tmp_path / "model" - - def assert_values_in_csv(file_path: Path, column: str, expected_values: list): """Helper function to check if values exist in a CSV column.""" df = pd.read_csv(file_path) @@ -75,19 +59,19 @@ def test_get_sectors(tmp_path): assert set(get_sectors(model_path)) == {"sector1", "sector2"} -def test_add_new_commodity(model_path): +def test_add_new_commodity(default_model_path): """Test the add_new_commodity function on the default model.""" - add_new_commodity(model_path, "new_commodity", "power", "wind") + add_new_commodity(default_model_path, "new_commodity", "power", "wind") # Check global commodities assert_values_in_csv( - model_path / "GlobalCommodities.csv", "commodity", ["new_commodity"] + default_model_path / "GlobalCommodities.csv", "commodity", ["new_commodity"] ) -def test_add_new_process(model_path): +def test_add_new_process(default_model_path): """Test the add_new_process function on the default model.""" - add_new_process(model_path, "new_process", "power", "windturbine") + add_new_process(default_model_path, "new_process", "power", "windturbine") files_to_check = [ "CommIn.csv", @@ -96,36 +80,44 
@@ def test_add_new_process(model_path): "Technodata.csv", ] for file in files_to_check: - assert_values_in_csv(model_path / "power" / file, "technology", ["new_process"]) + assert_values_in_csv( + default_model_path / "power" / file, "technology", ["new_process"] + ) -def test_technodata_for_new_year(model_path): +def test_technodata_for_new_year(default_model_path): """Test the add_price_data_for_new_year function on the default model.""" - add_technodata_for_new_year(model_path, 2030, "power", 2020) + add_technodata_for_new_year(default_model_path, 2030, "power", 2020) - assert_values_in_csv(model_path / "power" / "Technodata.csv", "year", [2030]) + assert_values_in_csv( + default_model_path / "power" / "Technodata.csv", "year", [2030] + ) -def test_add_agent(model_path_retro): +def test_add_agent(default_retro_model_path): """Test the add_agent function on the default_retro model.""" - add_agent(model_path_retro, "A2", "A1", "Agent3", "Agent4") + add_agent(default_retro_model_path, "A2", "A1", "Agent3", "Agent4") # Check Agents.csv - assert_values_in_csv(model_path_retro / "Agents.csv", "name", ["A2"]) + assert_values_in_csv(default_retro_model_path / "Agents.csv", "name", ["A2"]) for share in ["Agent3", "Agent4"]: - assert_values_in_csv(model_path_retro / "Agents.csv", "agent_share", [share]) + assert_values_in_csv( + default_retro_model_path / "Agents.csv", "agent_share", [share] + ) # Check Technodata.csv files for sector in ["power", "gas"]: - assert_columns_exist(model_path_retro / sector / "Technodata.csv", ["Agent4"]) + assert_columns_exist( + default_retro_model_path / sector / "Technodata.csv", ["Agent4"] + ) -def test_add_region(model_path): +def test_add_region(default_model_path): """Test the add_region function on the default model.""" - add_region(model_path, "R2", "R1") + add_region(default_model_path, "R2", "R1") # Check settings.toml - with open(model_path / "settings.toml") as f: + with open(default_model_path / "settings.toml") as f: settings = parse(f.read()) assert "R2" in settings["regions"] @@ -136,17 +128,17 @@ def test_add_region(model_path): "CommOut.csv", "ExistingCapacity.csv", ] - for sector in get_sectors(model_path): + for sector in get_sectors(default_model_path): for file in files_to_check: - assert_values_in_csv(model_path / sector / file, "region", ["R2"]) + assert_values_in_csv(default_model_path / sector / file, "region", ["R2"]) -def test_add_timeslice(model_path): +def test_add_timeslice(default_model_path): """Test the add_timeslice function on the default model.""" - add_timeslice(model_path, "midnight", "evening") + add_timeslice(default_model_path, "midnight", "evening") # Check settings.toml - with open(model_path / "settings.toml") as f: + with open(default_model_path / "settings.toml") as f: settings = parse(f.read()) timeslices = settings["timeslices"]["all-year"]["all-week"] assert "midnight" in timeslices @@ -154,5 +146,5 @@ def test_add_timeslice(model_path): # Check preset files for preset in ["Residential2020Consumption.csv", "Residential2050Consumption.csv"]: - df = pd.read_csv(model_path / "residential_presets" / preset) + df = pd.read_csv(default_model_path / "residential_presets" / preset) assert len(df["timeslice"].unique()) == n_timeslices From 88e77fc51818e538a3fa8bb2df7db14541927ab6 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 11 Aug 2025 10:53:38 +0100 Subject: [PATCH 6/7] Rename function --- src/muse/examples.py | 4 ++-- src/muse/readers/toml.py | 2 +- src/muse/sectors/sector.py | 4 ++-- tests/test_csv_readers.py | 12 
++++++------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/muse/examples.py b/src/muse/examples.py index 189a82c85..7cd2a426b 100644 --- a/src/muse/examples.py +++ b/src/muse/examples.py @@ -144,7 +144,7 @@ def technodata(sector: str, model: str = "default") -> xr.Dataset: """Technology for a sector of a given example model.""" from tempfile import TemporaryDirectory - from muse.readers.toml import read_settings, read_technodata + from muse.readers.toml import read_sector_technodata, read_settings sector = sector.lower() allowed = {"residential", "power", "gas", "preset"} @@ -155,7 +155,7 @@ def technodata(sector: str, model: str = "default") -> xr.Dataset: with TemporaryDirectory() as tmpdir: path = copy_model(model, tmpdir) settings = read_settings(path / "settings.toml") - return read_technodata(settings, sector) + return read_sector_technodata(settings, sector) def search_space(sector: str, model: str = "default") -> xr.DataArray: diff --git a/src/muse/readers/toml.py b/src/muse/readers/toml.py index 5c49320d5..50ded22f9 100644 --- a/src/muse/readers/toml.py +++ b/src/muse/readers/toml.py @@ -556,7 +556,7 @@ def check_subsector_settings(settings: dict) -> None: getLogger(__name__).warning(msg) -def read_technodata( +def read_sector_technodata( settings: Any, sector_name: str, interpolation_mode: str = "linear", diff --git a/src/muse/sectors/sector.py b/src/muse/sectors/sector.py index 6d7efa077..b12f6e912 100644 --- a/src/muse/sectors/sector.py +++ b/src/muse/sectors/sector.py @@ -11,7 +11,7 @@ from muse.agents import AbstractAgent from muse.production import PRODUCTION_SIGNATURE -from muse.readers.toml import read_technodata +from muse.readers.toml import read_sector_technodata from muse.sectors.abstract import AbstractSector from muse.sectors.register import register_sector from muse.sectors.subsector import Subsector @@ -44,7 +44,7 @@ def factory(cls, name: str, settings: Any) -> Sector: interactions_config = sector_settings.get("interactions", None) # Read technologies - technologies = read_technodata( + technologies = read_sector_technodata( settings, name, interpolation_mode=interpolation_mode, diff --git a/tests/test_csv_readers.py b/tests/test_csv_readers.py index 6feb9ca8d..b3fb39bb4 100644 --- a/tests/test_csv_readers.py +++ b/tests/test_csv_readers.py @@ -818,11 +818,11 @@ def test_read_technologies__timeslice(default_timeslice_model_path): ) -def test_read_technodata(default_model_path): - from muse.readers.toml import read_settings, read_technodata +def test_read_sector_technodata(default_model_path): + from muse.readers.toml import read_sector_technodata, read_settings settings = read_settings(default_model_path / "settings.toml") - data = read_technodata( + data = read_sector_technodata( settings, sector_name="power", interpolation_mode="linear", @@ -873,12 +873,12 @@ def test_read_technodata(default_model_path): ) -def test_read_technodata__trade(trade_model_path): +def test_read_sector_technodata__trade(trade_model_path): """Testing the read_technodata function with the trade model.""" - from muse.readers.toml import read_settings, read_technodata + from muse.readers.toml import read_sector_technodata, read_settings settings = read_settings(trade_model_path / "settings.toml") - data = read_technodata( + data = read_sector_technodata( settings, sector_name="power", interpolation_mode="linear", From eaef8f8c1bfa4745e97b4722be62de0e824210d1 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 11 Aug 2025 16:41:12 +0100 Subject: [PATCH 7/7] Renaming 
functions, a bit of restructuring --- src/muse/__init__.py | 8 +- src/muse/agents/factories.py | 9 +- src/muse/commodities.py | 4 +- src/muse/readers/csv/__init__.py | 56 ++++------ src/muse/readers/csv/agents.py | 16 ++- src/muse/readers/csv/assets.py | 17 ++- src/muse/readers/csv/commodities.py | 13 ++- ...gression.py => correlation_consumption.py} | 50 +++++++++ src/muse/readers/csv/helpers.py | 2 + src/muse/readers/csv/market.py | 13 +++ src/muse/readers/csv/presets.py | 5 + src/muse/readers/csv/technologies.py | 103 ++++++++++-------- .../readers/csv/{trade.py => trade_assets.py} | 52 +-------- src/muse/readers/csv/trade_technodata.py | 60 ++++++++++ src/muse/readers/toml.py | 60 ++-------- src/muse/regressions.py | 25 +---- src/muse/sectors/subsector.py | 6 +- tests/test_csv_readers.py | 32 +++--- tests/test_read_csv.py | 22 ++-- tests/test_subsector.py | 6 +- 20 files changed, 306 insertions(+), 253 deletions(-) rename src/muse/readers/csv/{regression.py => correlation_consumption.py} (75%) rename src/muse/readers/csv/{trade.py => trade_assets.py} (53%) create mode 100644 src/muse/readers/csv/trade_technodata.py diff --git a/src/muse/__init__.py b/src/muse/__init__.py index cac496a06..a55470a7b 100644 --- a/src/muse/__init__.py +++ b/src/muse/__init__.py @@ -96,14 +96,10 @@ def add_file_logger() -> None: "investments", "objectives", "outputs", - "read_agent_parameters", + "read_agents", + "read_existing_capacity", "read_global_commodities", - "read_initial_capacity", - "read_io_technodata", - "read_macro_drivers", "read_settings", - "read_technodictionary", "read_technologies", - "read_timeslice_shares", "sectors", ] diff --git a/src/muse/agents/factories.py b/src/muse/agents/factories.py index a9137a5ad..c6479874c 100644 --- a/src/muse/agents/factories.py +++ b/src/muse/agents/factories.py @@ -167,7 +167,7 @@ def create_agent(agent_type: str, **kwargs) -> Agent: def agents_factory( - params_or_path: str | Path | list, + path: Path, capacity: xr.DataArray, technologies: xr.Dataset, regions: Sequence[str] | None = None, @@ -178,12 +178,9 @@ def agents_factory( from copy import deepcopy from logging import getLogger - from muse.readers import read_agent_parameters + from muse.readers import read_agents - if isinstance(params_or_path, (str, Path)): - params = read_agent_parameters(params_or_path) - else: - params = params_or_path + params = read_agents(path) assert isinstance(capacity, xr.DataArray) if year is None: year = int(capacity.year.min()) diff --git a/src/muse/commodities.py b/src/muse/commodities.py index 517fd81e6..3874475b8 100644 --- a/src/muse/commodities.py +++ b/src/muse/commodities.py @@ -14,10 +14,10 @@ def setup_module(commodities_path: Path): """Sets up global commodities.""" - from muse.readers.csv import read_global_commodities + from muse.readers.csv import read_commodities global COMMODITIES - COMMODITIES = read_global_commodities(commodities_path) + COMMODITIES = read_commodities(commodities_path) class CommodityUsage(IntFlag): diff --git a/src/muse/readers/csv/__init__.py b/src/muse/readers/csv/__init__.py index d4a46985c..a57b25377 100644 --- a/src/muse/readers/csv/__init__.py +++ b/src/muse/readers/csv/__init__.py @@ -15,53 +15,45 @@ (usually an xarray). There are also some more checks (e.g. checking for nan values). 
-Most of the processing is shared by a few helper functions: -- read_csv: reads a csv file and returns a dataframe -- standardize_dataframe: standardizes the dataframe to a common format -- create_multiindex: creates a multiindex from a dataframe -- create_xarray_dataset: creates an xarray dataset from a dataframe +The code in this module is spread over multiple files. In general, we have one `read_x` +function per file, and as many `read_x_csv` and `process_x` functions as are required +(e.g. if a dataset is assembled from three csv files we will have three `read_x_csv` +functions, and potentially multiple `process_x` functions). + +Most of the processing is shared by a few helper functions (in `helpers.py`): +- `read_csv`: reads a csv file and returns a dataframe +- `standardize_dataframe`: standardizes the dataframe to a common format +- `create_multiindex`: creates a multiindex from a dataframe +- `create_xarray_dataset`: creates an xarray dataset from a dataframe A few other helpers perform common operations on xarrays: -- create_assets: creates assets from technologies -- check_commodities: checks commodities and fills missing values +- `create_assets`: creates assets from technologies +- `check_commodities`: checks commodities and fills missing values """ -from .agents import read_agent_parameters -from .assets import read_initial_capacity -from .commodities import read_global_commodities +from .agents import read_agents +from .assets import read_assets +from .commodities import read_commodities +from .correlation_consumption import read_correlation_consumption from .general import read_attribute_table from .helpers import read_csv from .market import read_initial_market from .presets import read_presets -from .regression import ( - read_macro_drivers, - read_regression_parameters, - read_timeslice_shares, -) -from .technologies import ( - read_io_technodata, - read_technodata_timeslices, - read_technodictionary, - read_technologies, -) -from .trade import read_existing_trade, read_trade_technodata +from .technologies import read_technologies +from .trade_assets import read_trade_assets +from .trade_technodata import read_trade_technodata __all__ = [ - "read_agent_parameters", + "read_agents", + "read_assets", "read_attribute_table", + "read_commodities", + "read_correlation_consumption", "read_csv", - "read_existing_trade", - "read_global_commodities", - "read_initial_capacity", "read_initial_market", - "read_io_technodata", - "read_macro_drivers", "read_presets", - "read_regression_parameters", - "read_technodata_timeslices", - "read_technodictionary", "read_technologies", - "read_timeslice_shares", + "read_trade_assets", "read_trade_technodata", ] diff --git a/src/muse/readers/csv/agents.py b/src/muse/readers/csv/agents.py index 947e679de..99046f237 100644 --- a/src/muse/readers/csv/agents.py +++ b/src/muse/readers/csv/agents.py @@ -1,3 +1,9 @@ +"""Reads and processes agent parameters from a CSV file. + +This runs once per subsector, reading in a csv file and outputting a list of +dictionaries (one dictionary per agent containing the agent's parameters). 
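A usage sketch of the renamed agent reader, assuming a copy of the default example model in ./model (the same layout the test suite uses); the split between the csv-level and full readers follows the read_x_csv / process_x pattern described above:

    from pathlib import Path

    from muse.readers.csv import read_agents
    from muse.readers.csv.agents import read_agents_csv

    agents_csv = Path("model") / "Agents.csv"

    # csv-level reader: a DataFrame that would still be valid as an input file
    df = read_agents_csv(agents_csv)

    # full reader: one dict of parameters per agent
    params = read_agents(agents_csv)
    assert isinstance(params, list)
    assert all(isinstance(p, dict) for p in params)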
+""" + from logging import getLogger from pathlib import Path @@ -6,13 +12,13 @@ from .helpers import read_csv -def read_agent_parameters(path: Path) -> pd.DataFrame: +def read_agents(path: Path) -> list[dict]: """Reads and processes agent parameters from a CSV file.""" - df = read_agent_parameters_csv(path) - return process_agent_parameters(df) + df = read_agents_csv(path) + return process_agents(df) -def read_agent_parameters_csv(path: Path) -> pd.DataFrame: +def read_agents_csv(path: Path) -> pd.DataFrame: """Reads standard MUSE agent-declaration csv-files into a DataFrame.""" required_columns = { "search_rule", @@ -57,7 +63,7 @@ def read_agent_parameters_csv(path: Path) -> pd.DataFrame: return data -def process_agent_parameters(data: pd.DataFrame) -> list[dict]: +def process_agents(data: pd.DataFrame) -> list[dict]: """Processes agent parameters DataFrame into a list of agent dictionaries.""" result = [] for _, row in data.iterrows(): diff --git a/src/muse/readers/csv/assets.py b/src/muse/readers/csv/assets.py index e5abeb1e9..0fd4090dc 100644 --- a/src/muse/readers/csv/assets.py +++ b/src/muse/readers/csv/assets.py @@ -1,3 +1,8 @@ +"""Reads and processes existing capacity data from a CSV file. + +This runs once per subsector, reading in a csv file and outputting an xarray DataArray. +""" + from pathlib import Path import pandas as pd @@ -6,13 +11,13 @@ from .helpers import create_assets, create_multiindex, create_xarray_dataset, read_csv -def read_initial_capacity(path: Path) -> xr.DataArray: - """Reads and processes initial capacity data from a CSV file.""" - df = read_initial_capacity_csv(path) - return process_initial_capacity(df) +def read_assets(path: Path) -> xr.DataArray: + """Reads and processes existing capacity data from a CSV file.""" + df = read_existing_capacity_csv(path) + return process_existing_capacity(df) -def read_initial_capacity_csv(path: Path) -> pd.DataFrame: +def read_existing_capacity_csv(path: Path) -> pd.DataFrame: """Reads and formats data about initial capacity into a DataFrame.""" required_columns = { "region", @@ -25,7 +30,7 @@ def read_initial_capacity_csv(path: Path) -> pd.DataFrame: ) -def process_initial_capacity(data: pd.DataFrame) -> xr.DataArray: +def process_existing_capacity(data: pd.DataFrame) -> xr.DataArray: """Processes initial capacity DataFrame into an xarray DataArray.""" # Drop unit column if present if "unit" in data.columns: diff --git a/src/muse/readers/csv/commodities.py b/src/muse/readers/csv/commodities.py index 8b3c25df0..e07c4fc69 100644 --- a/src/muse/readers/csv/commodities.py +++ b/src/muse/readers/csv/commodities.py @@ -1,3 +1,14 @@ +"""Reads and processes global commodities data from a CSV file. + +This runs once per simulation, reading in a csv file and outputting an xarray Dataset. + +The CSV file will generally have the following columns: `commodity`, `commodity_type`, +unit (optional), description (optional). + +The resulting xarray Dataset will have a single `commodity` dimension, and variables +for `commodity_type` and `unit`. 
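A similar sketch for the renamed commodities reader, under the same assumed ./model layout; it checks the structure described in the docstring above:

    from pathlib import Path

    from muse.readers.csv import read_commodities

    commodities = read_commodities(Path("model") / "GlobalCommodities.csv")

    # Single "commodity" dimension, with per-commodity metadata as variables.
    assert "commodity" in commodities.dims
    assert "commodity_type" in commodities.data_vars  # "unit" too, when provided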
+""" + from logging import getLogger from pathlib import Path @@ -7,7 +18,7 @@ from .helpers import camel_to_snake, create_xarray_dataset, standardize_dataframe -def read_global_commodities(path: Path) -> xr.Dataset: +def read_commodities(path: Path) -> xr.Dataset: """Reads and processes global commodities data from a CSV file.""" df = read_global_commodities_csv(path) return process_global_commodities(df) diff --git a/src/muse/readers/csv/regression.py b/src/muse/readers/csv/correlation_consumption.py similarity index 75% rename from src/muse/readers/csv/regression.py rename to src/muse/readers/csv/correlation_consumption.py index 1ab78e75f..21d5f4106 100644 --- a/src/muse/readers/csv/regression.py +++ b/src/muse/readers/csv/correlation_consumption.py @@ -1,3 +1,9 @@ +"""Reads and processes correlation consumption data from CSV files. + +This will only run for preset sectors that have macro drivers and regression files. +Otherwise, read_presets will be used instead. +""" + from __future__ import annotations from logging import getLogger @@ -14,6 +20,50 @@ ) +def read_correlation_consumption( + macro_drivers_path: Path, + regression_path: Path, + timeslice_shares_path: Path | None = None, +) -> xr.Dataset: + """Read consumption data for a sector based on correlation files. + + This function calculates endogenous demand for a sector using macro drivers and + regression parameters. It applies optional filters, handles sector aggregation, + and distributes the consumption across timeslices if timeslice shares are provided. + + Args: + macro_drivers_path: Path to macro drivers file + regression_path: Path to regression parameters file + timeslice_shares_path: Path to timeslice shares file (optional) + + Returns: + xr.Dataset: Consumption data distributed across timeslices and regions + """ + from muse.regressions import endogenous_demand + from muse.timeslices import broadcast_timeslice, distribute_timeslice + + macro_drivers = read_macro_drivers(macro_drivers_path) + regression_parameters = read_regression_parameters(regression_path) + consumption = endogenous_demand( + drivers=macro_drivers, + regression_parameters=regression_parameters, + forecast=0, + ) + + # Legacy: we permit regression parameters to split by sector, so have to sum + if "sector" in consumption.dims: + consumption = consumption.sum("sector") + + # Split by timeslice + if timeslice_shares_path is not None: + shares = read_timeslice_shares(timeslice_shares_path) + consumption = broadcast_timeslice(consumption) * shares + else: + consumption = distribute_timeslice(consumption) + + return consumption + + def read_timeslice_shares(path: Path) -> xr.DataArray: """Reads and processes timeslice shares data from a CSV file.""" df = read_timeslice_shares_csv(path) diff --git a/src/muse/readers/csv/helpers.py b/src/muse/readers/csv/helpers.py index c5909572a..e74e4611f 100644 --- a/src/muse/readers/csv/helpers.py +++ b/src/muse/readers/csv/helpers.py @@ -1,3 +1,5 @@ +"""Helper functions for reading and processing CSV files.""" + from __future__ import annotations from logging import getLogger diff --git a/src/muse/readers/csv/market.py b/src/muse/readers/csv/market.py index 3549669fc..add03faee 100644 --- a/src/muse/readers/csv/market.py +++ b/src/muse/readers/csv/market.py @@ -1,3 +1,16 @@ +"""Reads and processes initial market data. + +The data is shared between sectors, so we only do this once per simulation. 
+ +This data is contained in three csv files: +- projections: contains price projections for commodities +- base year import (optional): contains imports for commodities +- base year export (optional): contains exports for commodities + +A single xarray Dataset is returned, with dimensions for `region`, `year`, `commodity`, +and `timeslice`, and variables for `prices`, `exports`, `imports`, and `static_trade`. +""" + from __future__ import annotations from pathlib import Path diff --git a/src/muse/readers/csv/presets.py b/src/muse/readers/csv/presets.py index a9efe7043..abb68a6ac 100644 --- a/src/muse/readers/csv/presets.py +++ b/src/muse/readers/csv/presets.py @@ -1,3 +1,8 @@ +"""Reads and processes preset data from multiple CSV files. + +This runs once per sector, reading in csv files and outputting an xarray Dataset. +""" + from logging import getLogger from pathlib import Path diff --git a/src/muse/readers/csv/technologies.py b/src/muse/readers/csv/technologies.py index 680f2036f..a786a06d6 100644 --- a/src/muse/readers/csv/technologies.py +++ b/src/muse/readers/csv/technologies.py @@ -1,3 +1,14 @@ +"""Reads and processes technology data from multiple CSV files. + +This runs once per sector, reading in csv files and outputting an xarray Dataset. + +Several csv files are read in: +- technodictionary: contains technology parameters +- comm_out: contains output commodity data +- comm_in: contains input commodity data +- technodata_timeslices (optional): allows some parameters to be defined per timeslice +""" + from __future__ import annotations from logging import getLogger @@ -14,6 +25,52 @@ ) +def read_technologies( + technodata_path: Path, + comm_out_path: Path, + comm_in_path: Path, + time_framework: list[int], + interpolation_mode: str = "linear", + technodata_timeslices_path: Path | None = None, +) -> xr.Dataset: + """Reads and processes technology data from multiple CSV files. + + Will also interpolate data to the time framework if provided. + + Args: + technodata_path: path to the technodata file + comm_out_path: path to the comm_out file + comm_in_path: path to the comm_in file + time_framework: list of years to interpolate data to + interpolation_mode: Interpolation mode to use + technodata_timeslices_path: path to the technodata_timeslices file + + Returns: + xr.Dataset: Dataset containing the processed technology data. Any fields + that differ by year will contain a "year" dimension interpolated to the + time framework. Other fields will not have a "year" dimension. + """ + # Read all data + technodata = read_technodictionary(technodata_path) + comm_out = read_io_technodata(comm_out_path) + comm_in = read_io_technodata(comm_in_path) + technodata_timeslices = ( + read_technodata_timeslices(technodata_timeslices_path) + if technodata_timeslices_path + else None + ) + + # Assemble xarray Dataset + return process_technologies( + technodata, + comm_out, + comm_in, + time_framework, + interpolation_mode, + technodata_timeslices, + ) + + def read_technodictionary(path: Path) -> xr.Dataset: """Reads and processes technodictionary data from a CSV file.""" df = read_technodictionary_csv(path) @@ -199,52 +256,6 @@ def process_io_technodata(data: pd.DataFrame) -> xr.Dataset: return result -def read_technologies( - technodata_path: Path, - comm_out_path: Path, - comm_in_path: Path, - time_framework: list[int], - interpolation_mode: str = "linear", - technodata_timeslices_path: Path | None = None, -) -> xr.Dataset: - """Reads and processes technology data from multiple CSV files. 
- - Will also interpolate data to the time framework if provided. - - Args: - technodata_path: path to the technodata file - comm_out_path: path to the comm_out file - comm_in_path: path to the comm_in file - time_framework: list of years to interpolate data to - interpolation_mode: Interpolation mode to use - technodata_timeslices_path: path to the technodata_timeslices file - - Returns: - xr.Dataset: Dataset containing the processed technology data. Any fields - that differ by year will contain a "year" dimension interpolated to the - time framework. Other fields will not have a "year" dimension. - """ - # Read all data - technodata = read_technodictionary(technodata_path) - comm_out = read_io_technodata(comm_out_path) - comm_in = read_io_technodata(comm_in_path) - technodata_timeslices = ( - read_technodata_timeslices(technodata_timeslices_path) - if technodata_timeslices_path - else None - ) - - # Assemble xarray Dataset - return process_technologies( - technodata, - comm_out, - comm_in, - time_framework, - interpolation_mode, - technodata_timeslices, - ) - - def process_technologies( technodata: xr.Dataset, comm_out: xr.Dataset, diff --git a/src/muse/readers/csv/trade.py b/src/muse/readers/csv/trade_assets.py similarity index 53% rename from src/muse/readers/csv/trade.py rename to src/muse/readers/csv/trade_assets.py index 29efa305a..652f1ad3b 100644 --- a/src/muse/readers/csv/trade.py +++ b/src/muse/readers/csv/trade_assets.py @@ -1,3 +1,8 @@ +"""Reads and processes existing trade data from a CSV file. + +We only use this for trade sectors, otherwise we use read_assets instead. +""" + from pathlib import Path import pandas as pd @@ -11,52 +16,7 @@ ) -def read_trade_technodata(path: Path) -> xr.Dataset: - """Reads and processes trade technodata from a CSV file.""" - df = read_trade_technodata_csv(path) - return process_trade_technodata(df) - - -def read_trade_technodata_csv(path: Path) -> pd.DataFrame: - required_columns = {"technology", "region", "parameter"} - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading trade technodata from {path}.", - ) - - -def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select region columns - # TODO: this is a bit unsafe as user could supply other columns - regions = [ - col for col in data.columns if col not in ["technology", "region", "parameter"] - ] - - # Melt data over regions - data = data.melt( - id_vars=["technology", "region", "parameter"], - value_vars=regions, - var_name="dst_region", - value_name="value", - ) - - # Pivot data over parameters - data = data.pivot( - index=["technology", "region", "dst_region"], - columns="parameter", - values="value", - ) - - # Create DataSet - return create_xarray_dataset(data) - - -def read_existing_trade(path: Path) -> xr.DataArray: +def read_trade_assets(path: Path) -> xr.DataArray: """Reads and processes existing trade data from a CSV file.""" df = read_existing_trade_csv(path) return process_existing_trade(df) diff --git a/src/muse/readers/csv/trade_technodata.py b/src/muse/readers/csv/trade_technodata.py new file mode 100644 index 000000000..69240e952 --- /dev/null +++ b/src/muse/readers/csv/trade_technodata.py @@ -0,0 +1,60 @@ +"""Reads and processes trade technodata from a CSV file. + +We only use this for trade sectors. In this case, it gets added on the the dataset +created by `read_technologies`. 
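A hedged sketch of how the two pieces relate, assuming a copy of the bundled trade example model in ./model; the real combination happens inside read_sector_technodata in readers/toml.py (not shown in this hunk), so the plain xr.merge below is purely illustrative:

    from pathlib import Path

    import xarray as xr

    from muse.readers.csv import read_technologies, read_trade_technodata

    sector = Path("model") / "power"
    technologies = read_technologies(
        technodata_path=sector / "Technodata.csv",
        comm_out_path=sector / "CommOut.csv",
        comm_in_path=sector / "CommIn.csv",
        time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050],
    )
    trade = read_trade_technodata(sector / "TradeTechnodata.csv")

    # Trade parameters carry an extra dst_region dimension; attaching them to
    # the technology dataset is shown here with a simple merge for illustration.
    combined = xr.merge([technologies, trade])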
+""" + +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + create_xarray_dataset, + read_csv, +) + + +def read_trade_technodata(path: Path) -> xr.Dataset: + """Reads and processes trade technodata from a CSV file.""" + df = read_trade_technodata_csv(path) + return process_trade_technodata(df) + + +def read_trade_technodata_csv(path: Path) -> pd.DataFrame: + required_columns = {"technology", "region", "parameter"} + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading trade technodata from {path}.", + ) + + +def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select region columns + # TODO: this is a bit unsafe as user could supply other columns + regions = [ + col for col in data.columns if col not in ["technology", "region", "parameter"] + ] + + # Melt data over regions + data = data.melt( + id_vars=["technology", "region", "parameter"], + value_vars=regions, + var_name="dst_region", + value_name="value", + ) + + # Pivot data over parameters + data = data.pivot( + index=["technology", "region", "dst_region"], + columns="parameter", + values="value", + ) + + # Create DataSet + return create_xarray_dataset(data) diff --git a/src/muse/readers/toml.py b/src/muse/readers/toml.py index 50ded22f9..06020ebd0 100644 --- a/src/muse/readers/toml.py +++ b/src/muse/readers/toml.py @@ -646,7 +646,11 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: xr.Dataset: Dataset containing consumption and supply data for the sector. Costs are initialized to zero. """ - from muse.readers import read_attribute_table, read_presets + from muse.readers import ( + read_attribute_table, + read_correlation_consumption, + read_presets, + ) from muse.timeslices import distribute_timeslice, drop_timeslice sector_conf = getattr(settings.sectors, sector_name) @@ -662,7 +666,11 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: getattr(sector_conf, "macrodrivers_path", None) is not None and getattr(sector_conf, "regression_path", None) is not None ): - consumption = read_correlation_consumption(sector_conf) + consumption = read_correlation_consumption( + macro_drivers_path=sector_conf.macrodrivers_path, + regression_path=sector_conf.regression_path, + timeslice_shares_path=getattr(sector_conf, "timeslice_shares_path", None), + ) else: raise MissingSettings(f"Missing consumption data for sector {sector_name}") @@ -678,51 +686,3 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: ) return presets - - -def read_correlation_consumption(sector_conf: Any) -> xr.Dataset: - """Read consumption data for a sector based on correlation files. - - This function calculates endogenous demand for a sector using macro drivers and - regression parameters. It applies optional filters, handles sector aggregation, - and distributes the consumption across timeslices if timeslice shares are provided. 
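For reference, the csv-level replacement takes the three file paths directly (mirroring the call above); a usage sketch against the correlation example model, with the directory layout assumed from the tests:

    from pathlib import Path

    from muse.readers.csv import read_correlation_consumption

    presets = Path("model") / "residential_presets"
    consumption = read_correlation_consumption(
        macro_drivers_path=presets / "Macrodrivers.csv",
        regression_path=presets / "regressionparameters.csv",
        timeslice_shares_path=presets / "TimesliceSharepreset.csv",  # optional
    )
    # Endogenous demand, summed over any legacy "sector" dimension and split
    # across timeslices (via the shares file here, since one is given).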
- - Args: - sector_conf: Sector configuration object containing paths to macro drivers, - regression parameters, and timeslice shares files - - Returns: - xr.Dataset: Consumption data distributed across timeslices and regions - """ - from muse.readers import ( - read_macro_drivers, - read_regression_parameters, - read_timeslice_shares, - ) - from muse.regressions import endogenous_demand - from muse.timeslices import broadcast_timeslice, distribute_timeslice - - macro_drivers = read_macro_drivers(sector_conf.macrodrivers_path) - regression_parameters = read_regression_parameters(sector_conf.regression_path) - consumption = endogenous_demand( - drivers=macro_drivers, - regression_parameters=regression_parameters, - forecast=0, - ) - - # Legacy: apply filters - if hasattr(sector_conf, "filters"): - consumption = consumption.sel(sector_conf.filters._asdict()) - - # Legacy: we permit regression parameters to split by sector, so have to sum - if "sector" in consumption.dims: - consumption = consumption.sum("sector") - - # Split by timeslice - if sector_conf.timeslice_shares_path is not None: - shares = read_timeslice_shares(sector_conf.timeslice_shares_path) - consumption = broadcast_timeslice(consumption) * shares - else: - consumption = distribute_timeslice(consumption) - - return consumption diff --git a/src/muse/regressions.py b/src/muse/regressions.py index 39560abe3..26dc2cd2e 100644 --- a/src/muse/regressions.py +++ b/src/muse/regressions.py @@ -48,7 +48,7 @@ class Regression(Callable): `xarray.Dataset` as input. In any case, it is given the gpd and population. These can be read from standard MUSE csv files: - >>> from muse.readers import read_macro_drivers + >>> from muse.readers.csv.correlation_consumption import read_macro_drivers >>> from muse.defaults import DATA_DIRECTORY >>> path_to_macrodrivers = DATA_DIRECTORY / "Macrodrivers.csv" >>> if path_to_macrodrivers.exists(): @@ -117,20 +117,15 @@ def _split_kwargs(data: Dataset, **kwargs) -> tuple[Mapping, Mapping]: @classmethod def factory( cls, - regression_data: str | Path | Dataset, + regression_data: Dataset, interpolation: str = "linear", base_year: int = 2010, **filters, ) -> Regression: """Creates a regression function from standard muse input.""" - from muse.readers import read_regression_parameters - assert cls.__mappings__ assert cls.__regression__ != "" - if isinstance(regression_data, (str, Path)): - regression_data = read_regression_parameters(regression_data) - # Get the parameters of interest with a 'simple' name coeffs = Dataset({k: regression_data[v] for k, v in cls.__mappings__.items()}) filters.update(coeffs.data_vars) @@ -138,15 +133,10 @@ def factory( def factory( - regression_parameters: str | Path | Dataset, + regression_parameters: Dataset, sector: str | Sequence[str] | None = None, ) -> Regression: """Creates regression functor from standard MUSE data for given sector.""" - from muse.readers import read_regression_parameters - - if isinstance(regression_parameters, (str, Path)): - regression_parameters = read_regression_parameters(regression_parameters) - if sector is not None: regression_parameters = regression_parameters.sel(sector=sector) @@ -203,7 +193,6 @@ def register_regression( registered regression functor. 
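Since the regression helpers no longer accept paths, callers read the csv data first and pass xarray Datasets in. A minimal sketch, assuming the same correlation example layout as above:

    from pathlib import Path

    from muse.readers.csv.correlation_consumption import (
        read_macro_drivers,
        read_regression_parameters,
    )
    from muse.regressions import endogenous_demand

    presets = Path("model") / "residential_presets"
    drivers = read_macro_drivers(presets / "Macrodrivers.csv")
    parameters = read_regression_parameters(presets / "regressionparameters.csv")

    # Both arguments must now be Datasets; no implicit file reading remains.
    demand = endogenous_demand(
        regression_parameters=parameters, drivers=drivers, forecast=0
    )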
""" from logging import getLogger - from pathlib import Path from muse.registration import name_variations @@ -504,15 +493,11 @@ def __call__( def endogenous_demand( - regression_parameters: str | Path | Dataset, - drivers: str | Path | Dataset, + regression_parameters: Dataset, + drivers: Dataset, sector: str | Sequence | None = None, **kwargs, ) -> Dataset: """Endogenous demand based on macro drivers and regression parameters.""" - from muse.readers import read_macro_drivers - regression = factory(regression_parameters, sector=sector) - if isinstance(drivers, (str, Path)): - drivers = read_macro_drivers(drivers) return regression(drivers, **kwargs) diff --git a/src/muse/sectors/subsector.py b/src/muse/sectors/subsector.py index b7903178c..14f2698aa 100644 --- a/src/muse/sectors/subsector.py +++ b/src/muse/sectors/subsector.py @@ -118,7 +118,7 @@ def factory( from muse import investments as iv from muse.agents import InvestingAgent, agents_factory from muse.commodities import is_enduse - from muse.readers import read_csv, read_existing_trade, read_initial_capacity + from muse.readers import read_assets, read_csv, read_trade_assets # Read existing capacity or existing trade file # Have to peek at the file to determine what format the data is in @@ -126,9 +126,9 @@ def factory( # the parameter name in the settings file df = read_csv(settings.existing_capacity) if "year" not in df.columns: - existing_capacity = read_initial_capacity(settings.existing_capacity) + existing_capacity = read_assets(settings.existing_capacity) else: - existing_capacity = read_existing_trade(settings.existing_capacity) + existing_capacity = read_trade_assets(settings.existing_capacity) # Create agents agents = agents_factory( diff --git a/tests/test_csv_readers.py b/tests/test_csv_readers.py index b3fb39bb4..8eb3de513 100644 --- a/tests/test_csv_readers.py +++ b/tests/test_csv_readers.py @@ -126,10 +126,10 @@ def assert_single_coordinate(data, selection, expected): def test_read_global_commodities(default_model_path): - from muse.readers.csv import read_global_commodities + from muse.readers.csv import read_commodities path = default_model_path / "GlobalCommodities.csv" - data = read_global_commodities(path) + data = read_commodities(path) # Check data against schema expected_schema = DatasetSchema( @@ -265,7 +265,7 @@ def test_read_initial_market(default_model_path): def test_read_technodictionary(default_model_path): - from muse.readers.csv import read_technodictionary + from muse.readers.csv.technologies import read_technodictionary data = read_technodictionary(default_model_path / "power" / "Technodata.csv") @@ -335,7 +335,7 @@ def test_read_technodictionary(default_model_path): def test_read_technodata_timeslices(default_timeslice_model_path): - from muse.readers.csv import read_technodata_timeslices + from muse.readers.csv.technologies import read_technodata_timeslices data = read_technodata_timeslices( default_timeslice_model_path / "power" / "TechnodataTimeslices.csv" @@ -385,7 +385,7 @@ def test_read_technodata_timeslices(default_timeslice_model_path): def test_read_io_technodata(default_model_path): - from muse.readers.csv import read_io_technodata + from muse.readers.csv.technologies import read_io_technodata data = read_io_technodata(default_model_path / "power" / "CommIn.csv") @@ -422,10 +422,10 @@ def test_read_io_technodata(default_model_path): assert_single_coordinate(data, coord, expected) -def test_read_initial_capacity(default_model_path): - from muse.readers.csv import read_initial_capacity +def 
test_read_existing_capacity(default_model_path): + from muse.readers.csv import read_assets - data = read_initial_capacity(default_model_path / "power" / "ExistingCapacity.csv") + data = read_assets(default_model_path / "power" / "ExistingCapacity.csv") # Check data against schema expected_schema = DataArraySchema( @@ -458,10 +458,10 @@ def test_read_initial_capacity(default_model_path): assert data.sel(region="r1", asset=0, year=2020).item() == 1 -def test_read_agent_parameters(default_model_path): - from muse.readers.csv import read_agent_parameters +def test_read_agents(default_model_path): + from muse.readers.csv import read_agents - data = read_agent_parameters(default_model_path / "Agents.csv") + data = read_agents(default_model_path / "Agents.csv") assert isinstance(data, list) assert len(data) == 1 @@ -482,9 +482,9 @@ def test_read_agent_parameters(default_model_path): def test_read_existing_trade(trade_model_path): - from muse.readers.csv import read_existing_trade + from muse.readers.csv import read_trade_assets - data = read_existing_trade(trade_model_path / "gas" / "ExistingTrade.csv") + data = read_trade_assets(trade_model_path / "gas" / "ExistingTrade.csv") # Check data against schema expected_schema = DataArraySchema( @@ -563,7 +563,7 @@ def test_read_trade_technodata(trade_model_path): def test_read_timeslice_shares(default_correlation_model_path): - from muse.readers.csv import read_timeslice_shares + from muse.readers.csv.correlation_consumption import read_timeslice_shares data = read_timeslice_shares( default_correlation_model_path @@ -607,7 +607,7 @@ def test_read_timeslice_shares(default_correlation_model_path): def test_read_macro_drivers(default_correlation_model_path): - from muse.readers.csv import read_macro_drivers + from muse.readers.csv.correlation_consumption import read_macro_drivers data = read_macro_drivers( default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" @@ -646,7 +646,7 @@ def test_read_macro_drivers(default_correlation_model_path): def test_read_regression_parameters(default_correlation_model_path): - from muse.readers.csv import read_regression_parameters + from muse.readers.csv.correlation_consumption import read_regression_parameters data = read_regression_parameters( default_correlation_model_path diff --git a/tests/test_read_csv.py b/tests/test_read_csv.py index 014345705..83b7d89e6 100644 --- a/tests/test_read_csv.py +++ b/tests/test_read_csv.py @@ -55,12 +55,12 @@ def test_read_technodata_timeslices_csv(default_timeslice_model_path): assert set(timeslices_df.columns) == mandatory_columns | extra_columns -def test_read_initial_capacity_csv(default_model_path): +def test_read_existing_capacity_csv(default_model_path): """Test reading the initial capacity CSV file.""" - from muse.readers.csv.assets import read_initial_capacity_csv + from muse.readers.csv.assets import read_existing_capacity_csv capacity_path = default_model_path / "power" / "ExistingCapacity.csv" - capacity_df = read_initial_capacity_csv(capacity_path) + capacity_df = read_existing_capacity_csv(capacity_path) mandatory_columns = { "region", "technology", @@ -93,7 +93,7 @@ def test_read_global_commodities_csv(default_model_path): def test_read_timeslice_shares_csv(default_correlation_model_path): """Test reading the timeslice shares CSV file.""" - from muse.readers.csv.regression import read_timeslice_shares_csv + from muse.readers.csv.correlation_consumption import read_timeslice_shares_csv shares_path = ( default_correlation_model_path @@ -115,12 +115,12 
@@ def test_read_timeslice_shares_csv(default_correlation_model_path): assert set(shares_df.columns) == mandatory_columns | extra_columns -def test_read_agent_parameters_csv(default_model_path): +def test_read_agents_csv(default_model_path): """Test reading the agent parameters CSV file.""" - from muse.readers.csv.agents import read_agent_parameters_csv + from muse.readers.csv.agents import read_agents_csv agents_path = default_model_path / "Agents.csv" - agents_df = read_agent_parameters_csv(agents_path) + agents_df = read_agents_csv(agents_path) mandatory_columns = { "search_rule", "quantity", @@ -140,7 +140,7 @@ def test_read_agent_parameters_csv(default_model_path): def test_read_macro_drivers_csv(default_correlation_model_path): """Test reading the macro drivers CSV file.""" - from muse.readers.csv.regression import read_macro_drivers_csv + from muse.readers.csv.correlation_consumption import read_macro_drivers_csv macro_path = ( default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" @@ -180,7 +180,7 @@ def test_read_projections_csv(default_model_path): def test_read_regression_parameters_csv(default_correlation_model_path): """Test reading the regression parameters CSV file.""" - from muse.readers.csv.regression import read_regression_parameters_csv + from muse.readers.csv.correlation_consumption import read_regression_parameters_csv regression_path = ( default_correlation_model_path @@ -224,7 +224,7 @@ def test_read_presets_csv(default_model_path): def test_read_existing_trade_csv(trade_model_path): """Test reading the existing trade CSV file.""" - from muse.readers.csv.trade import read_existing_trade_csv + from muse.readers.csv.trade_assets import read_existing_trade_csv trade_path = trade_model_path / "power" / "ExistingTrade.csv" trade_df = read_existing_trade_csv(trade_path) @@ -239,7 +239,7 @@ def test_read_existing_trade_csv(trade_model_path): def test_read_trade_technodata(trade_model_path): """Test reading the trade technodata CSV file.""" - from muse.readers.csv.trade import read_trade_technodata_csv + from muse.readers.csv.trade_technodata import read_trade_technodata_csv trade_path = trade_model_path / "power" / "TradeTechnodata.csv" trade_df = read_trade_technodata_csv(trade_path) diff --git a/tests/test_subsector.py b/tests/test_subsector.py index 29affb24b..93eae1fb3 100644 --- a/tests/test_subsector.py +++ b/tests/test_subsector.py @@ -8,7 +8,7 @@ from muse import examples from muse import investments as iv from muse.agents.factories import create_agent -from muse.readers import read_agent_parameters, read_initial_capacity +from muse.readers import read_agents, read_assets from muse.readers.toml import read_settings from muse.sectors.subsector import Subsector, aggregate_enduses @@ -41,8 +41,8 @@ def agent_params(model, tmp_path, technologies): """Common agent parameters setup.""" examples.copy_model(model, tmp_path) path = tmp_path / "model" / "Agents.csv" - params = read_agent_parameters(path) - capa = read_initial_capacity(path.with_name("residential") / "ExistingCapacity.csv") + params = read_agents(path) + capa = read_assets(path.with_name("residential") / "ExistingCapacity.csv") for param in params: param.update(