diff --git a/src/muse/__init__.py b/src/muse/__init__.py index cac496a06..a55470a7b 100644 --- a/src/muse/__init__.py +++ b/src/muse/__init__.py @@ -96,14 +96,10 @@ def add_file_logger() -> None: "investments", "objectives", "outputs", - "read_agent_parameters", + "read_agents", + "read_existing_capacity", "read_global_commodities", - "read_initial_capacity", - "read_io_technodata", - "read_macro_drivers", "read_settings", - "read_technodictionary", "read_technologies", - "read_timeslice_shares", "sectors", ] diff --git a/src/muse/agents/factories.py b/src/muse/agents/factories.py index a9137a5ad..c6479874c 100644 --- a/src/muse/agents/factories.py +++ b/src/muse/agents/factories.py @@ -167,7 +167,7 @@ def create_agent(agent_type: str, **kwargs) -> Agent: def agents_factory( - params_or_path: str | Path | list, + path: Path, capacity: xr.DataArray, technologies: xr.Dataset, regions: Sequence[str] | None = None, @@ -178,12 +178,9 @@ def agents_factory( from copy import deepcopy from logging import getLogger - from muse.readers import read_agent_parameters + from muse.readers import read_agents - if isinstance(params_or_path, (str, Path)): - params = read_agent_parameters(params_or_path) - else: - params = params_or_path + params = read_agents(path) assert isinstance(capacity, xr.DataArray) if year is None: year = int(capacity.year.min()) diff --git a/src/muse/commodities.py b/src/muse/commodities.py index 517fd81e6..3874475b8 100644 --- a/src/muse/commodities.py +++ b/src/muse/commodities.py @@ -14,10 +14,10 @@ def setup_module(commodities_path: Path): """Sets up global commodities.""" - from muse.readers.csv import read_global_commodities + from muse.readers.csv import read_commodities global COMMODITIES - COMMODITIES = read_global_commodities(commodities_path) + COMMODITIES = read_commodities(commodities_path) class CommodityUsage(IntFlag): diff --git a/src/muse/examples.py b/src/muse/examples.py index 189a82c85..7cd2a426b 100644 --- a/src/muse/examples.py +++ b/src/muse/examples.py @@ -144,7 +144,7 @@ def technodata(sector: str, model: str = "default") -> xr.Dataset: """Technology for a sector of a given example model.""" from tempfile import TemporaryDirectory - from muse.readers.toml import read_settings, read_technodata + from muse.readers.toml import read_sector_technodata, read_settings sector = sector.lower() allowed = {"residential", "power", "gas", "preset"} @@ -155,7 +155,7 @@ def technodata(sector: str, model: str = "default") -> xr.Dataset: with TemporaryDirectory() as tmpdir: path = copy_model(model, tmpdir) settings = read_settings(path / "settings.toml") - return read_technodata(settings, sector) + return read_sector_technodata(settings, sector) def search_space(sector: str, model: str = "default") -> xr.DataArray: diff --git a/src/muse/readers/__init__.py b/src/muse/readers/__init__.py index 16900043d..9c7edf403 100644 --- a/src/muse/readers/__init__.py +++ b/src/muse/readers/__init__.py @@ -1,8 +1,9 @@ """Aggregates methods to read data from file.""" from muse.defaults import DATA_DIRECTORY -from muse.readers.csv import * # noqa: F403 -from muse.readers.toml import read_settings # noqa: F401 + +from .csv import * # noqa: F403 +from .toml import read_settings # noqa: F401 DEFAULT_SETTINGS_PATH = DATA_DIRECTORY / "default_settings.toml" """Default settings path.""" diff --git a/src/muse/readers/csv.py b/src/muse/readers/csv.py deleted file mode 100644 index 9a5018f63..000000000 --- a/src/muse/readers/csv.py +++ /dev/null @@ -1,1475 +0,0 @@ -"""Ensemble of functions to read 
MUSE data. - -In general, there are three functions per input file: -`read_x`: This is the overall function that is called to read the data. It takes a - `Path` as input, and returns the relevant data structure (usually an xarray). The - process is generally broken down into two functions that are called by `read_x`: - -`read_x_csv`: This takes a path to a csv file as input and returns a pandas dataframe. - There are some consistency checks, such as checking data types and columns. There - is also some minor processing at this stage, such as standardising column names, - but no structural changes to the data. The general rule is that anything returned - by this function should still be valid as an input file if saved to csv. -`process_x`: This is where more major processing and reformatting of the data is done. - It takes the dataframe from `read_x_csv` and returns the final data structure - (usually an xarray). There are also some more checks (e.g. checking for nan - values). - -Most of the processing is shared by a few helper functions: -- read_csv: reads a csv file and returns a dataframe -- standardize_dataframe: standardizes the dataframe to a common format -- create_multiindex: creates a multiindex from a dataframe -- create_xarray_dataset: creates an xarray dataset from a dataframe - -A few other helpers perform common operations on xarrays: -- create_assets: creates assets from technologies -- check_commodities: checks commodities and fills missing values - -""" - -from __future__ import annotations - -__all__ = [ - "read_agent_parameters", - "read_attribute_table", - "read_csv", - "read_existing_trade", - "read_global_commodities", - "read_initial_capacity", - "read_initial_market", - "read_io_technodata", - "read_macro_drivers", - "read_presets", - "read_regression_parameters", - "read_technodata_timeslices", - "read_technodictionary", - "read_technologies", - "read_timeslice_shares", - "read_trade_technodata", -] - -from logging import getLogger -from pathlib import Path - -import pandas as pd -import xarray as xr - -from muse.utilities import camel_to_snake - -# Global mapping of column names to their standardized versions -# This is for backwards compatibility with old file formats -COLUMN_RENAMES = { - "process_name": "technology", - "process": "technology", - "sector_name": "sector", - "region_name": "region", - "time": "year", - "commodity_name": "commodity", - "comm_type": "commodity_type", - "commodity_price": "prices", - "units_commodity_price": "units_prices", - "enduse": "end_use", - "sn": "timeslice", - "commodity_emission_factor_CO2": "emmission_factor", - "utilisation_factor": "utilization_factor", - "objsort": "obj_sort", - "objsort1": "obj_sort1", - "objsort2": "obj_sort2", - "objsort3": "obj_sort3", - "time_slice": "timeslice", - "price": "prices", -} - -# Columns who's values should be converted from camelCase to snake_case -CAMEL_TO_SNAKE_COLUMNS = [ - "tech_type", - "commodity", - "commodity_type", - "agent_share", - "attribute", - "sector", - "region", - "parameter", -] - -# Global mapping of column names to their expected types -COLUMN_TYPES = { - "year": int, - "region": str, - "technology": str, - "commodity": str, - "sector": str, - "attribute": str, - "variable": str, - "timeslice": int, # For tables that require int timeslice instead of month etc. 
- "name": str, - "commodity_type": str, - "tech_type": str, - "type": str, - "function_type": str, - "level": str, - "search_rule": str, - "decision_method": str, - "quantity": float, - "share": float, - "coeff": str, - "value": float, - "utilization_factor": float, - "minimum_service_factor": float, - "maturity_threshold": float, - "spend_limit": float, - "prices": float, - "emmission_factor": float, -} - -DEFAULTS = { - "cap_par": 0, - "cap_exp": 1, - "fix_par": 0, - "fix_exp": 1, - "var_par": 0, - "var_exp": 1, - "interest_rate": 0, - "utilization_factor": 1, - "minimum_service_factor": 0, - "search_rule": "all", - "decision_method": "single", -} - - -def standardize_columns(data: pd.DataFrame) -> pd.DataFrame: - """Standardizes column names in a DataFrame. - - This function: - 1. Converts column names to snake_case - 2. Applies the global COLUMN_RENAMES mapping - 3. Preserves any columns not in the mapping - - Args: - data: DataFrame to standardize - - Returns: - DataFrame with standardized column names - """ - # Drop index column if present - if data.columns[0] == "" or data.columns[0].startswith("Unnamed"): - data = data.iloc[:, 1:] - - # Convert columns to snake_case - data = data.rename(columns=camel_to_snake) - - # Then apply global mapping - data = data.rename(columns=COLUMN_RENAMES) - - # Make sure there are no duplicate columns - if len(data.columns) != len(set(data.columns)): - raise ValueError(f"Duplicate columns in {data.columns}") - - return data - - -def create_multiindex( - data: pd.DataFrame, - index_columns: list[str], - index_names: list[str], - drop_columns: bool = True, -) -> pd.DataFrame: - """Creates a MultiIndex from specified columns. - - Args: - data: DataFrame to create index from - index_columns: List of column names to use for index - index_names: List of names for the index levels - drop_columns: Whether to drop the original columns - - Returns: - DataFrame with new MultiIndex - """ - index = pd.MultiIndex.from_arrays( - [data[col] for col in index_columns], names=index_names - ) - result = data.copy() - result.index = index - if drop_columns: - result = result.drop(columns=index_columns) - return result - - -def create_xarray_dataset( - data: pd.DataFrame, - disallow_nan: bool = True, -) -> xr.Dataset: - """Creates an xarray Dataset from a DataFrame with standardized options. - - Args: - data: DataFrame to convert - disallow_nan: Whether to raise an error if NaN values are found - - Returns: - xarray Dataset - """ - result = xr.Dataset.from_dataframe(data) - if disallow_nan: - nan_coords = get_nan_coordinates(result) - if nan_coords: - raise ValueError(f"Missing data for coordinates: {nan_coords}") - - if "year" in result.coords: - result = result.assign_coords(year=result.year.astype(int)) - result = result.sortby("year") - assert len(set(result.year.values)) == result.year.data.size # no duplicates - - return result - - -def get_nan_coordinates(dataset: xr.Dataset) -> list[tuple]: - """Get coordinates of a Dataset where any data variable has NaN values.""" - any_nan = sum(var.isnull() for var in dataset.data_vars.values()) - if any_nan.any(): - return any_nan.where(any_nan, drop=True).to_dataframe(name="").index.to_list() - return [] - - -def convert_column_types(data: pd.DataFrame) -> pd.DataFrame: - """Converts DataFrame columns to their expected types. 
- - Args: - data: DataFrame to convert - - Returns: - DataFrame with converted column types - """ - result = data.copy() - for column, expected_type in COLUMN_TYPES.items(): - if column in result.columns: - try: - if expected_type is int: - result[column] = pd.to_numeric(result[column], downcast="integer") - elif expected_type is float: - result[column] = pd.to_numeric(result[column]).astype(float) - elif expected_type is str: - result[column] = result[column].astype(str) - except (ValueError, TypeError) as e: - raise ValueError( - f"Could not convert column '{column}' to {expected_type.__name__}: {e}" # noqa: E501 - ) - return result - - -def standardize_dataframe( - data: pd.DataFrame, - required_columns: list[str] | None = None, - exclude_extra_columns: bool = False, -) -> pd.DataFrame: - """Standardizes a DataFrame to a common format. - - Args: - data: DataFrame to standardize - required_columns: List of column names that must be present (optional) - exclude_extra_columns: If True, exclude any columns not in required_columns list - (optional). This can be important if extra columns can mess up the resulting - xarray object. - - Returns: - DataFrame containing the standardized data - """ - if required_columns is None: - required_columns = [] - - # Standardize column names - data = standardize_columns(data) - - # Convert specified column values from camelCase to snake_case - for col in CAMEL_TO_SNAKE_COLUMNS: - if col in data.columns: - data[col] = data[col].apply(camel_to_snake) - - # Fill missing values with defaults - data = data.fillna(DEFAULTS) - for col, default in DEFAULTS.items(): - if col not in data.columns and col in required_columns: - data[col] = default - - # Check/convert data types - data = convert_column_types(data) - - # Validate required columns if provided - if required_columns: - missing_columns = [col for col in required_columns if col not in data.columns] - if missing_columns: - raise ValueError(f"Missing required columns: {missing_columns}") - - # Exclude extra columns if requested - if exclude_extra_columns: - data = data[list(required_columns)] - - return data - - -def read_csv( - path: Path, - float_precision: str = "high", - required_columns: list[str] | None = None, - exclude_extra_columns: bool = False, - msg: str | None = None, -) -> pd.DataFrame: - """Reads and standardizes a CSV file into a DataFrame. - - Args: - path: Path to the CSV file - float_precision: Precision to use when reading floats - required_columns: List of column names that must be present (optional) - exclude_extra_columns: If True, exclude any columns not in required_columns list - (optional). This can be important if extra columns can mess up the resulting - xarray object. 
- msg: Message to log (optional) - - Returns: - DataFrame containing the standardized data - """ - # Log message - if msg: - getLogger(__name__).info(msg) - - # Check if file exists - if not path.is_file(): - raise OSError(f"{path} does not exist.") - - # Check if there's a units row (in which case we need to skip it) - with open(path) as f: - next(f) # Skip header row - first_data_row = f.readline().strip() - skiprows = [1] if first_data_row.startswith("Unit") else None - - # Read the file - data = pd.read_csv( - path, - float_precision=float_precision, - low_memory=False, - skiprows=skiprows, - ) - - # Standardize the DataFrame - return standardize_dataframe( - data, - required_columns=required_columns, - exclude_extra_columns=exclude_extra_columns, - ) - - -def check_commodities( - data: xr.Dataset | xr.DataArray, fill_missing: bool = True, fill_value: float = 0 -) -> xr.Dataset | xr.DataArray: - """Validates and optionally fills missing commodities in data.""" - from muse.commodities import COMMODITIES - - # Make sure there are no commodities in data but not in global commodities - extra_commodities = [ - c for c in data.commodity.values if c not in COMMODITIES.commodity.values - ] - if extra_commodities: - raise ValueError( - "The following commodities were not found in global commodities file: " - f"{extra_commodities}" - ) - - # Add any missing commodities with fill_value - if fill_missing: - data = data.reindex( - commodity=COMMODITIES.commodity.values, fill_value=fill_value - ) - return data - - -def create_assets(data: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: - """Creates assets from technology data.""" - # Rename technology to asset - result = data.drop_vars("technology").rename(technology="asset") - result["technology"] = "asset", data.technology.values - - # Add installed year - result["installed"] = ("asset", [int(result.year.min())] * len(result.technology)) - return result - - -def read_technodictionary(path: Path) -> xr.Dataset: - """Reads and processes technodictionary data from a CSV file.""" - df = read_technodictionary_csv(path) - return process_technodictionary(df) - - -def read_technodictionary_csv(path: Path) -> pd.DataFrame: - """Reads and formats technodata into a DataFrame.""" - required_columns = { - "cap_exp", - "region", - "var_par", - "fix_exp", - "interest_rate", - "utilization_factor", - "minimum_service_factor", - "year", - "cap_par", - "var_exp", - "technology", - "technical_life", - "fix_par", - } - data = read_csv( - path, - required_columns=required_columns, - msg=f"Reading technodictionary from {path}.", - ) - - # Check for deprecated columns - if "fuel" in data.columns: - msg = ( - f"The 'fuel' column in {path} has been deprecated. " - "This information is now determined from CommIn files. " - "Please remove this column from your Technodata files." - ) - getLogger(__name__).warning(msg) - if "end_use" in data.columns: - msg = ( - f"The 'end_use' column in {path} has been deprecated. " - "This information is now determined from CommOut files. " - "Please remove this column from your Technodata files." - ) - getLogger(__name__).warning(msg) - if "scaling_size" in data.columns: - msg = ( - f"The 'scaling_size' column in {path} has been deprecated. " - "Please remove this column from your Technodata files." 
- ) - getLogger(__name__).warning(msg) - - return data - - -def process_technodictionary(data: pd.DataFrame) -> xr.Dataset: - """Processes technodictionary DataFrame into an xarray Dataset.""" - # Create multiindex for technology and region - data = create_multiindex( - data, - index_columns=["technology", "region", "year"], - index_names=["technology", "region", "year"], - drop_columns=True, - ) - - # Create dataset - result = create_xarray_dataset(data) - - # Handle tech_type if present - if "type" in result.variables: - result["tech_type"] = result.type.isel(region=0, year=0) - - return result - - -def read_technodata_timeslices(path: Path) -> xr.Dataset: - """Reads and processes technodata timeslices from a CSV file.""" - df = read_technodata_timeslices_csv(path) - return process_technodata_timeslices(df) - - -def read_technodata_timeslices_csv(path: Path) -> pd.DataFrame: - """Reads and formats technodata timeslices into a DataFrame.""" - from muse.timeslices import TIMESLICE - - timeslice_columns = set(TIMESLICE.coords["timeslice"].indexes["timeslice"].names) - required_columns = { - "utilization_factor", - "technology", - "minimum_service_factor", - "region", - "year", - } | timeslice_columns - return read_csv( - path, - required_columns=required_columns, - exclude_extra_columns=True, - msg=f"Reading technodata timeslices from {path}.", - ) - - -def process_technodata_timeslices(data: pd.DataFrame) -> xr.Dataset: - """Processes technodata timeslices DataFrame into an xarray Dataset.""" - from muse.timeslices import TIMESLICE, sort_timeslices - - # Create multiindex for all columns except factor columns - factor_columns = ["utilization_factor", "minimum_service_factor", "obj_sort"] - index_columns = [col for col in data.columns if col not in factor_columns] - data = create_multiindex( - data, - index_columns=index_columns, - index_names=index_columns, - drop_columns=True, - ) - - # Create dataset - result = create_xarray_dataset(data) - - # Stack timeslice levels (month, day, hour) into a single timeslice dimension - timeslice_levels = TIMESLICE.coords["timeslice"].indexes["timeslice"].names - if all(level in result.dims for level in timeslice_levels): - result = result.stack(timeslice=timeslice_levels) - return sort_timeslices(result) - - -def read_io_technodata(path: Path) -> xr.Dataset: - """Reads and processes input/output technodata from a CSV file.""" - df = read_io_technodata_csv(path) - return process_io_technodata(df) - - -def read_io_technodata_csv(path: Path) -> pd.DataFrame: - """Reads process inputs or outputs into a DataFrame.""" - data = read_csv( - path, - required_columns=["technology", "region", "year"], - msg=f"Reading IO technodata from {path}.", - ) - - # Unspecified Level values default to "fixed" - if "level" in data.columns: - data["level"] = data["level"].fillna("fixed") - else: - # Particularly relevant to outputs files where the Level column is omitted by - # default, as only "fixed" outputs are allowed. - data["level"] = "fixed" - - return data - - -def process_io_technodata(data: pd.DataFrame) -> xr.Dataset: - """Processes IO technodata DataFrame into an xarray Dataset.""" - from muse.commodities import COMMODITIES - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. 
single "commodity" column) - data = data.melt( - id_vars=["technology", "region", "year", "level"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Pivot data to create fixed and flexible columns - data = data.pivot( - index=["technology", "region", "year", "commodity"], - columns="level", - values="value", - ) - - # Create xarray dataset - result = create_xarray_dataset(data) - - # Fill in flexible data - if "flexible" in result.data_vars: - result["flexible"] = result.flexible.fillna(0) - else: - result["flexible"] = xr.zeros_like(result.fixed).rename("flexible") - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_technologies( - technodata_path: Path, - comm_out_path: Path, - comm_in_path: Path, - time_framework: list[int], - interpolation_mode: str = "linear", - technodata_timeslices_path: Path | None = None, -) -> xr.Dataset: - """Reads and processes technology data from multiple CSV files. - - Will also interpolate data to the time framework if provided. - - Args: - technodata_path: path to the technodata file - comm_out_path: path to the comm_out file - comm_in_path: path to the comm_in file - time_framework: list of years to interpolate data to - interpolation_mode: Interpolation mode to use - technodata_timeslices_path: path to the technodata_timeslices file - - Returns: - xr.Dataset: Dataset containing the processed technology data. Any fields - that differ by year will contain a "year" dimension interpolated to the - time framework. Other fields will not have a "year" dimension. - """ - # Read all data - technodata = read_technodictionary(technodata_path) - comm_out = read_io_technodata(comm_out_path) - comm_in = read_io_technodata(comm_in_path) - technodata_timeslices = ( - read_technodata_timeslices(technodata_timeslices_path) - if technodata_timeslices_path - else None - ) - - # Assemble xarray Dataset - return process_technologies( - technodata, - comm_out, - comm_in, - time_framework, - interpolation_mode, - technodata_timeslices, - ) - - -def process_technologies( - technodata: xr.Dataset, - comm_out: xr.Dataset, - comm_in: xr.Dataset, - time_framework: list[int], - interpolation_mode: str = "linear", - technodata_timeslices: xr.Dataset | None = None, -) -> xr.Dataset: - """Processes technology data DataFrames into an xarray Dataset.""" - from muse.commodities import COMMODITIES, CommodityUsage - from muse.timeslices import drop_timeslice - from muse.utilities import interpolate_technodata - - # Process inputs/outputs - ins = comm_in.rename(flexible="flexible_inputs", fixed="fixed_inputs") - outs = comm_out.rename(flexible="flexible_outputs", fixed="fixed_outputs") - - # Legacy: Remove flexible outputs - if not (outs["flexible_outputs"] == 0).all(): - raise ValueError( - "'flexible' outputs are not permitted. 
All outputs must be 'fixed'" - ) - outs = outs.drop_vars("flexible_outputs") - - # Collect all years from the time framework and data files - time_framework = list( - set(time_framework).union( - technodata.year.values.tolist(), - ins.year.values.tolist(), - outs.year.values.tolist(), - technodata_timeslices.year.values.tolist() if technodata_timeslices else [], - ) - ) - - # Interpolate data to match the time framework - technodata = interpolate_technodata(technodata, time_framework, interpolation_mode) - outs = interpolate_technodata(outs, time_framework, interpolation_mode) - ins = interpolate_technodata(ins, time_framework, interpolation_mode) - if technodata_timeslices: - technodata_timeslices = interpolate_technodata( - technodata_timeslices, time_framework, interpolation_mode - ) - - # Merge inputs/outputs with technodata - technodata = technodata.merge(outs).merge(ins) - - # Merge technodata_timeslices if provided. This will prioritise values defined in - # technodata_timeslices, and fallback to the non-timesliced technodata for any - # values that are not defined in technodata_timeslices. - if technodata_timeslices: - technodata["utilization_factor"] = ( - technodata_timeslices.utilization_factor.combine_first( - technodata.utilization_factor - ) - ) - technodata["minimum_service_factor"] = drop_timeslice( - technodata_timeslices.minimum_service_factor.combine_first( - technodata.minimum_service_factor - ) - ) - - # Check commodities - technodata = check_commodities(technodata, fill_missing=False) - - # Add info about commodities - technodata = technodata.merge(COMMODITIES.sel(commodity=technodata.commodity)) - - # Add commodity usage flags - technodata["comm_usage"] = ( - "commodity", - CommodityUsage.from_technologies(technodata).values, - ) - technodata = technodata.drop_vars("commodity_type") - - # Check utilization and minimum service factors - check_utilization_and_minimum_service_factors(technodata) - - return technodata - - -def read_initial_capacity(path: Path) -> xr.DataArray: - """Reads and processes initial capacity data from a CSV file.""" - df = read_initial_capacity_csv(path) - return process_initial_capacity(df) - - -def read_initial_capacity_csv(path: Path) -> pd.DataFrame: - """Reads and formats data about initial capacity into a DataFrame.""" - required_columns = { - "region", - "technology", - } - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading initial capacity from {path}.", - ) - - -def process_initial_capacity(data: pd.DataFrame) -> xr.DataArray: - """Processes initial capacity DataFrame into an xarray DataArray.""" - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select year columns - year_columns = [col for col in data.columns if col.isdigit()] - - # Convert year columns to long format (i.e. 
single "year" column) - data = data.melt( - id_vars=["technology", "region"], - value_vars=year_columns, - var_name="year", - value_name="value", - ) - - # Create multiindex for region, technology, and year - data = create_multiindex( - data, - index_columns=["technology", "region", "year"], - index_names=["technology", "region", "year"], - drop_columns=True, - ) - - # Create Dataarray - result = create_xarray_dataset(data).value.astype(float) - - # Create assets - result = create_assets(result) - return result - - -def read_global_commodities(path: Path) -> xr.Dataset: - """Reads and processes global commodities data from a CSV file.""" - df = read_global_commodities_csv(path) - return process_global_commodities(df) - - -def read_global_commodities_csv(path: Path) -> pd.DataFrame: - """Reads commodities information from input into a DataFrame.""" - # Due to legacy reasons, users can supply both Commodity and CommodityName columns - # In this case, we need to remove the Commodity column to avoid conflicts - # This is fine because Commodity just contains a long description that isn't needed - getLogger(__name__).info(f"Reading global commodities from {path}.") - df = pd.read_csv(path) - df = df.rename(columns=camel_to_snake) - if "commodity" in df.columns and "commodity_name" in df.columns: - df = df.drop(columns=["commodity"]) - - required_columns = { - "commodity", - "commodity_type", - } - data = standardize_dataframe( - df, - required_columns=required_columns, - ) - - # Raise warning if units are not defined - if "unit" not in data.columns: - msg = ( - "No units defined for commodities. Please define units for all commodities " - "in the global commodities file." - ) - getLogger(__name__).warning(msg) - - return data - - -def process_global_commodities(data: pd.DataFrame) -> xr.Dataset: - """Processes global commodities DataFrame into an xarray Dataset.""" - # Drop description column if present. It's useful to include in the file, but we - # don't need it for the simulation. - if "description" in data.columns: - data = data.drop(columns=["description"]) - - data.index = [u for u in data.commodity] - data = data.drop("commodity", axis=1) - data.index.name = "commodity" - return create_xarray_dataset(data) - - -def read_agent_parameters(path: Path) -> pd.DataFrame: - """Reads and processes agent parameters from a CSV file.""" - df = read_agent_parameters_csv(path) - return process_agent_parameters(df) - - -def read_agent_parameters_csv(path: Path) -> pd.DataFrame: - """Reads standard MUSE agent-declaration csv-files into a DataFrame.""" - required_columns = { - "search_rule", - "quantity", - "region", - "type", - "name", - "agent_share", - "decision_method", - } - data = read_csv( - path, - required_columns=required_columns, - msg=f"Reading agent parameters from {path}.", - ) - - # Check for deprecated retrofit agents - if "type" in data.columns: - retrofit_agents = data[data.type.str.lower().isin(["retrofit", "retro"])] - if not retrofit_agents.empty: - msg = ( - "Retrofit agents will be deprecated in a future release. " - "Please modify your model to use only agents of the 'New' type." 
- ) - getLogger(__name__).warning(msg) - - # Legacy: drop AgentNumber column - if "agent_number" in data.columns: - data = data.drop(["agent_number"], axis=1) - - # Check consistency of objectives data columns - objectives = [col for col in data.columns if col.startswith("objective")] - floats = [col for col in data.columns if col.startswith("obj_data")] - sorting = [col for col in data.columns if col.startswith("obj_sort")] - - if len(objectives) != len(floats) or len(objectives) != len(sorting): - raise ValueError( - "Agent objective, obj_data, and obj_sort columns are inconsistent in " - f"{path}" - ) - - return data - - -def process_agent_parameters(data: pd.DataFrame) -> list[dict]: - """Processes agent parameters DataFrame into a list of agent dictionaries.""" - result = [] - for _, row in data.iterrows(): - # Get objectives data - objectives = ( - row[[i.startswith("objective") for i in row.index]].dropna().to_list() - ) - sorting = row[[i.startswith("obj_sort") for i in row.index]].dropna().to_list() - floats = row[[i.startswith("obj_data") for i in row.index]].dropna().to_list() - - # Create decision parameters - decision_params = list(zip(objectives, sorting, floats)) - - agent_type = { - "new": "newcapa", - "newcapa": "newcapa", - "retrofit": "retrofit", - "retro": "retrofit", - "agent": "agent", - "default": "agent", - }[getattr(row, "type", "agent").lower()] - - # Create agent data dictionary - data = { - "name": row["name"], - "region": row.region, - "objectives": objectives, - "search_rules": row.search_rule, - "decision": {"name": row.decision_method, "parameters": decision_params}, - "agent_type": agent_type, - "quantity": row.quantity, - "share": row.agent_share, - } - - # Add optional parameters - if hasattr(row, "maturity_threshold"): - data["maturity_threshold"] = row.maturity_threshold - if hasattr(row, "spend_limit"): - data["spend_limit"] = row.spend_limit - - # Add agent data to result - result.append(data) - - return result - - -def read_initial_market( - projections_path: Path, - base_year_import_path: Path | None = None, - base_year_export_path: Path | None = None, - currency: str | None = None, -) -> xr.Dataset: - """Reads and processes initial market data. - - Args: - projections_path: path to the projections file - base_year_import_path: path to the base year import file (optional) - base_year_export_path: path to the base year export file (optional) - currency: currency string (e.g. "USD") - - Returns: - xr.Dataset: Dataset containing initial market data. - """ - # Read projections - projections_df = read_projections_csv(projections_path) - - # Read base year export (optional) - if base_year_export_path: - export_df = read_csv( - base_year_export_path, - msg=f"Reading base year export from {base_year_export_path}.", - ) - else: - export_df = None - - # Read base year import (optional) - if base_year_import_path: - import_df = read_csv( - base_year_import_path, - msg=f"Reading base year import from {base_year_import_path}.", - ) - else: - import_df = None - - # Assemble into xarray Dataset - result = process_initial_market(projections_df, import_df, export_df, currency) - return result - - -def read_projections_csv(path: Path) -> pd.DataFrame: - """Reads projections data from a CSV file.""" - required_columns = { - "region", - "attribute", - "year", - } - projections_df = read_csv( - path, required_columns=required_columns, msg=f"Reading projections from {path}." 
- ) - return projections_df - - -def process_initial_market( - projections_df: pd.DataFrame, - import_df: pd.DataFrame | None, - export_df: pd.DataFrame | None, - currency: str | None = None, -) -> xr.Dataset: - """Process market data DataFrames into an xarray Dataset. - - Args: - projections_df: DataFrame containing projections data - import_df: Optional DataFrame containing import data - export_df: Optional DataFrame containing export data - currency: Currency string (e.g. "USD") - """ - from muse.commodities import COMMODITIES - from muse.timeslices import broadcast_timeslice, distribute_timeslice - - # Process projections - projections = process_attribute_table(projections_df).commodity_price.astype( - "float64" - ) - - # Process optional trade data - if export_df is not None: - base_year_export = process_attribute_table(export_df).exports.astype("float64") - else: - base_year_export = xr.zeros_like(projections) - - if import_df is not None: - base_year_import = process_attribute_table(import_df).imports.astype("float64") - else: - base_year_import = xr.zeros_like(projections) - - # Distribute data over timeslices - projections = broadcast_timeslice(projections, level=None) - base_year_export = distribute_timeslice(base_year_export, level=None) - base_year_import = distribute_timeslice(base_year_import, level=None) - - # Assemble into xarray - result = xr.Dataset( - { - "prices": projections, - "exports": base_year_export, - "imports": base_year_import, - "static_trade": base_year_import - base_year_export, - } - ) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - - # Add units_prices coordinate - # Only added if the currency is specified and commodity units are defined - if currency and "unit" in COMMODITIES.data_vars: - units_prices = [ - f"{currency}/{COMMODITIES.sel(commodity=c).unit.item()}" - for c in result.commodity.values - ] - result = result.assign_coords(units_prices=("commodity", units_prices)) - - return result - - -def read_attribute_table(path: Path) -> xr.Dataset: - """Reads and processes attribute table data from a CSV file.""" - df = read_attribute_table_csv(path) - return process_attribute_table(df) - - -def read_attribute_table_csv(path: Path) -> pd.DataFrame: - """Read a standard MUSE csv file for price projections into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "attribute", "year"], - msg=f"Reading attribute table from {path}.", - ) - return table - - -def process_attribute_table(data: pd.DataFrame) -> xr.Dataset: - """Process attribute table DataFrame into an xarray Dataset.""" - # Extract commodity columns - commodities = [ - col for col in data.columns if col not in ["region", "year", "attribute"] - ] - - # Convert commodity columns to long format (i.e. single "commodity" column) - data = data.melt( - id_vars=["region", "year", "attribute"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Pivot data over attributes - data = data.pivot( - index=["region", "year", "commodity"], - columns="attribute", - values="value", - ) - - # Create DataSet - result = create_xarray_dataset(data) - return result - - -def read_presets(presets_paths: Path) -> xr.Dataset: - """Reads and processes preset data from multiple CSV files. - - Accepts a path pattern for presets files, e.g. `Path("path/to/*Consumption.csv")`. - The file name of each file must contain a year (e.g. "2020Consumption.csv"). 
- """ - from glob import glob - from re import match - - # Find all files matching the path pattern - allfiles = [Path(p) for p in glob(str(presets_paths))] - if len(allfiles) == 0: - raise OSError(f"No files found with paths {presets_paths}") - - # Read all files - datas: dict[int, pd.DataFrame] = {} - for path in allfiles: - # Extract year from filename - reyear = match(r"\S*.(\d{4})\S*\.csv", path.name) - if reyear is None: - raise OSError(f"Unexpected filename {path.name}") - year = int(reyear.group(1)) - if year in datas: - raise OSError(f"Year f{year} was found twice") - - # Read data - data = read_presets_csv(path) - data["year"] = year - datas[year] = data - - # Process data - datas = process_presets(datas) - return datas - - -def read_presets_csv(path: Path) -> pd.DataFrame: - data = read_csv( - path, - required_columns=["region", "timeslice"], - msg=f"Reading presets from {path}.", - ) - - # Legacy: drop technology column and sum data (PR #448) - if "technology" in data.columns: - getLogger(__name__).warning( - f"The technology (or ProcessName) column in file {path} is " - "deprecated. Data has been summed across technologies, and this column " - "has been dropped." - ) - data = ( - data.drop(columns=["technology"]) - .groupby(["region", "timeslice"]) - .sum() - .reset_index() - ) - - return data - - -def process_presets(datas: dict[int, pd.DataFrame]) -> xr.Dataset: - """Processes preset DataFrames into an xarray Dataset.""" - from muse.commodities import COMMODITIES - from muse.timeslices import TIMESLICE - - # Combine into a single DataFrame - data = pd.concat(datas.values()) - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. 
single "commodity" column) - data = data.melt( - id_vars=["region", "year", "timeslice"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Create multiindex for region, year, timeslice and commodity - data = create_multiindex( - data, - index_columns=["region", "year", "timeslice", "commodity"], - index_names=["region", "year", "timeslice", "commodity"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Assign timeslices - result = result.assign_coords(timeslice=TIMESLICE.timeslice) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_trade_technodata(path: Path) -> xr.Dataset: - """Reads and processes trade technodata from a CSV file.""" - df = read_trade_technodata_csv(path) - return process_trade_technodata(df) - - -def read_trade_technodata_csv(path: Path) -> pd.DataFrame: - required_columns = {"technology", "region", "parameter"} - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading trade technodata from {path}.", - ) - - -def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select region columns - # TODO: this is a bit unsafe as user could supply other columns - regions = [ - col for col in data.columns if col not in ["technology", "region", "parameter"] - ] - - # Melt data over regions - data = data.melt( - id_vars=["technology", "region", "parameter"], - value_vars=regions, - var_name="dst_region", - value_name="value", - ) - - # Pivot data over parameters - data = data.pivot( - index=["technology", "region", "dst_region"], - columns="parameter", - values="value", - ) - - # Create DataSet - return create_xarray_dataset(data) - - -def read_existing_trade(path: Path) -> xr.DataArray: - """Reads and processes existing trade data from a CSV file.""" - df = read_existing_trade_csv(path) - return process_existing_trade(df) - - -def read_existing_trade_csv(path: Path) -> pd.DataFrame: - required_columns = { - "region", - "technology", - "year", - } - return read_csv( - path, - required_columns=required_columns, - msg=f"Reading existing trade from {path}.", - ) - - -def process_existing_trade(data: pd.DataFrame) -> xr.DataArray: - # Select region columns - # TODO: this is a bit unsafe as user could supply other columns - regions = [ - col for col in data.columns if col not in ["technology", "region", "year"] - ] - - # Melt data over regions - data = data.melt( - id_vars=["technology", "region", "year"], - value_vars=regions, - var_name="dst_region", - value_name="value", - ) - - # Create multiindex for region, dst_region, technology and year - data = create_multiindex( - data, - index_columns=["region", "dst_region", "technology", "year"], - index_names=["region", "dst_region", "technology", "year"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Create assets from technologies - result = create_assets(result) - return result - - -def read_timeslice_shares(path: Path) -> xr.DataArray: - """Reads and processes timeslice shares data from a CSV file.""" - df = read_timeslice_shares_csv(path) - return process_timeslice_shares(df) - - -def read_timeslice_shares_csv(path: Path) -> pd.DataFrame: - """Reads sliceshare information into a DataFrame.""" - data = read_csv( - path, - required_columns=["region", "timeslice"], - 
msg=f"Reading timeslice shares from {path}.", - ) - - return data - - -def process_timeslice_shares(data: pd.DataFrame) -> xr.DataArray: - """Processes timeslice shares DataFrame into an xarray DataArray.""" - from muse.commodities import COMMODITIES - from muse.timeslices import TIMESLICE - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Convert commodity columns to long format (i.e. single "commodity" column) - data = data.melt( - id_vars=["region", "timeslice"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Create multiindex for region and timeslice - data = create_multiindex( - data, - index_columns=["region", "timeslice", "commodity"], - index_names=["region", "timeslice", "commodity"], - drop_columns=True, - ) - - # Create DataArray - result = create_xarray_dataset(data).value.astype(float) - - # Assign timeslices - result = result.assign_coords(timeslice=TIMESLICE.timeslice) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def read_macro_drivers(path: Path) -> pd.DataFrame: - """Reads and processes macro drivers data from a CSV file.""" - df = read_macro_drivers_csv(path) - return process_macro_drivers(df) - - -def read_macro_drivers_csv(path: Path) -> pd.DataFrame: - """Reads a standard MUSE csv file for macro drivers into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "variable"], - msg=f"Reading macro drivers from {path}.", - ) - - # Validate required variables - required_variables = ["Population", "GDP|PPP"] - missing_variables = [ - var for var in required_variables if var not in table.variable.unique() - ] - if missing_variables: - raise ValueError(f"Missing required variables in {path}: {missing_variables}") - - return table - - -def process_macro_drivers(data: pd.DataFrame) -> xr.Dataset: - """Processes macro drivers DataFrame into an xarray Dataset.""" - # Drop unit column if present - if "unit" in data.columns: - data = data.drop(columns=["unit"]) - - # Select year columns - year_columns = [col for col in data.columns if col.isdigit()] - - # Convert year columns to long format (i.e. single "year" column) - data = data.melt( - id_vars=["variable", "region"], - value_vars=year_columns, - var_name="year", - value_name="value", - ) - - # Pivot data to create Population and GDP|PPP columns - data = data.pivot( - index=["region", "year"], - columns="variable", - values="value", - ) - - # Legacy: rename Population to population and GDP|PPP to gdp - if "Population" in data.columns: - data = data.rename(columns={"Population": "population"}) - if "GDP|PPP" in data.columns: - data = data.rename(columns={"GDP|PPP": "gdp"}) - - # Create DataSet - result = create_xarray_dataset(data) - return result - - -def read_regression_parameters(path: Path) -> xr.Dataset: - """Reads and processes regression parameters from a CSV file.""" - df = read_regression_parameters_csv(path) - return process_regression_parameters(df) - - -def read_regression_parameters_csv(path: Path) -> pd.DataFrame: - """Reads the regression parameters from a MUSE csv file into a DataFrame.""" - table = read_csv( - path, - required_columns=["region", "function_type", "coeff"], - msg=f"Reading regression parameters from {path}.", - ) - - # Legacy: warn about "sector" column - if "sector" in table.columns: - getLogger(__name__).warning( - f"The sector column (in file {path}) is deprecated. Please remove." 
- ) - - return table - - -def process_regression_parameters(data: pd.DataFrame) -> xr.Dataset: - """Processes regression parameters DataFrame into an xarray Dataset.""" - from muse.commodities import COMMODITIES - - # Extract commodity columns - commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] - - # Melt to long format - melted = data.melt( - id_vars=["sector", "region", "function_type", "coeff"], - value_vars=commodities, - var_name="commodity", - value_name="value", - ) - - # Extract sector -> function_type mapping - sector_to_ftype = melted.drop_duplicates(["sector", "function_type"])[ - ["sector", "function_type"] - ].set_index("sector")["function_type"] - - # Pivot to create coefficient variables - pivoted = melted.pivot_table( - index=["sector", "region", "commodity"], columns="coeff", values="value" - ) - - # Create dataset and add function_type - result = create_xarray_dataset(pivoted) - result["function_type"] = xr.DataArray( - sector_to_ftype[result.sector.values].astype(object), - dims=["sector"], - name="function_type", - ) - - # Check commodities - result = check_commodities(result, fill_missing=True, fill_value=0) - return result - - -def check_utilization_and_minimum_service_factors(data: xr.Dataset) -> None: - """Check utilization and minimum service factors in an xarray dataset. - - Args: - data: xarray Dataset containing utilization_factor and minimum_service_factor - """ - if "utilization_factor" not in data.data_vars: - raise ValueError( - "A technology needs to have a utilization factor defined for every " - "timeslice." - ) - - # Check UF not all zero (sum across timeslice dimension if it exists) - if "timeslice" in data.dims: - utilization_sum = data.utilization_factor.sum(dim="timeslice") - else: - utilization_sum = data.utilization_factor - - if (utilization_sum == 0).any(): - raise ValueError( - "A technology can not have a utilization factor of 0 for every timeslice." - ) - - # Check UF in range - utilization = data.utilization_factor - if not ((utilization >= 0) & (utilization <= 1)).all(): - raise ValueError( - "Utilization factor values must all be between 0 and 1 inclusive." - ) - - # Check MSF in range - min_service_factor = data.minimum_service_factor - if not ((min_service_factor >= 0) & (min_service_factor <= 1)).all(): - raise ValueError( - "Minimum service factor values must all be between 0 and 1 inclusive." - ) - - # Check UF not below MSF - if (data.utilization_factor < data.minimum_service_factor).any(): - raise ValueError( - "Utilization factors must all be greater than or equal " - "to their corresponding minimum service factors." - ) diff --git a/src/muse/readers/csv/__init__.py b/src/muse/readers/csv/__init__.py new file mode 100644 index 000000000..a57b25377 --- /dev/null +++ b/src/muse/readers/csv/__init__.py @@ -0,0 +1,59 @@ +"""Ensemble of functions to read MUSE data. + +In general, there are three functions per input file: +`read_x`: This is the overall function that is called to read the data. It takes a + `Path` as input, and returns the relevant data structure (usually an xarray). The + process is generally broken down into two functions that are called by `read_x`: + +`read_x_csv`: This takes a path to a csv file as input and returns a pandas dataframe. + There are some consistency checks, such as checking data types and columns. There + is also some minor processing at this stage, such as standardising column names, + but no structural changes to the data. 
The general rule is that anything returned + by this function should still be valid as an input file if saved to csv. +`process_x`: This is where more major processing and reformatting of the data is done. + It takes the dataframe from `read_x_csv` and returns the final data structure + (usually an xarray). There are also some more checks (e.g. checking for nan + values). + +The code in this module is spread over multiple files. In general, we have one `read_x` +function per file, and as many `read_x_csv` and `process_x` functions as are required +(e.g. if a dataset is assembled from three csv files we will have three `read_x_csv` +functions, and potentially multiple `process_x` functions). + +Most of the processing is shared by a few helper functions (in `helpers.py`): +- `read_csv`: reads a csv file and returns a dataframe +- `standardize_dataframe`: standardizes the dataframe to a common format +- `create_multiindex`: creates a multiindex from a dataframe +- `create_xarray_dataset`: creates an xarray dataset from a dataframe + +A few other helpers perform common operations on xarrays: +- `create_assets`: creates assets from technologies +- `check_commodities`: checks commodities and fills missing values + +""" + +from .agents import read_agents +from .assets import read_assets +from .commodities import read_commodities +from .correlation_consumption import read_correlation_consumption +from .general import read_attribute_table +from .helpers import read_csv +from .market import read_initial_market +from .presets import read_presets +from .technologies import read_technologies +from .trade_assets import read_trade_assets +from .trade_technodata import read_trade_technodata + +__all__ = [ + "read_agents", + "read_assets", + "read_attribute_table", + "read_commodities", + "read_correlation_consumption", + "read_csv", + "read_initial_market", + "read_presets", + "read_technologies", + "read_trade_assets", + "read_trade_technodata", +] diff --git a/src/muse/readers/csv/agents.py b/src/muse/readers/csv/agents.py new file mode 100644 index 000000000..99046f237 --- /dev/null +++ b/src/muse/readers/csv/agents.py @@ -0,0 +1,110 @@ +"""Reads and processes agent parameters from a CSV file. + +This runs once per subsector, reading in a csv file and outputting a list of +dictionaries (one dictionary per agent containing the agent's parameters). +""" + +from logging import getLogger +from pathlib import Path + +import pandas as pd + +from .helpers import read_csv + + +def read_agents(path: Path) -> list[dict]: + """Reads and processes agent parameters from a CSV file.""" + df = read_agents_csv(path) + return process_agents(df) + + +def read_agents_csv(path: Path) -> pd.DataFrame: + """Reads standard MUSE agent-declaration csv-files into a DataFrame.""" + required_columns = { + "search_rule", + "quantity", + "region", + "type", + "name", + "agent_share", + "decision_method", + } + data = read_csv( + path, + required_columns=required_columns, + msg=f"Reading agent parameters from {path}.", + ) + + # Check for deprecated retrofit agents + if "type" in data.columns: + retrofit_agents = data[data.type.str.lower().isin(["retrofit", "retro"])] + if not retrofit_agents.empty: + msg = ( + "Retrofit agents will be deprecated in a future release. " + "Please modify your model to use only agents of the 'New' type." 
+ ) + getLogger(__name__).warning(msg) + + # Legacy: drop AgentNumber column + if "agent_number" in data.columns: + data = data.drop(["agent_number"], axis=1) + + # Check consistency of objectives data columns + objectives = [col for col in data.columns if col.startswith("objective")] + floats = [col for col in data.columns if col.startswith("obj_data")] + sorting = [col for col in data.columns if col.startswith("obj_sort")] + + if len(objectives) != len(floats) or len(objectives) != len(sorting): + raise ValueError( + "Agent objective, obj_data, and obj_sort columns are inconsistent in " + f"{path}" + ) + + return data + + +def process_agents(data: pd.DataFrame) -> list[dict]: + """Processes agent parameters DataFrame into a list of agent dictionaries.""" + result = [] + for _, row in data.iterrows(): + # Get objectives data + objectives = ( + row[[i.startswith("objective") for i in row.index]].dropna().to_list() + ) + sorting = row[[i.startswith("obj_sort") for i in row.index]].dropna().to_list() + floats = row[[i.startswith("obj_data") for i in row.index]].dropna().to_list() + + # Create decision parameters + decision_params = list(zip(objectives, sorting, floats)) + + agent_type = { + "new": "newcapa", + "newcapa": "newcapa", + "retrofit": "retrofit", + "retro": "retrofit", + "agent": "agent", + "default": "agent", + }[getattr(row, "type", "agent").lower()] + + # Create agent data dictionary + data = { + "name": row["name"], + "region": row.region, + "objectives": objectives, + "search_rules": row.search_rule, + "decision": {"name": row.decision_method, "parameters": decision_params}, + "agent_type": agent_type, + "quantity": row.quantity, + "share": row.agent_share, + } + + # Add optional parameters + if hasattr(row, "maturity_threshold"): + data["maturity_threshold"] = row.maturity_threshold + if hasattr(row, "spend_limit"): + data["spend_limit"] = row.spend_limit + + # Add agent data to result + result.append(data) + + return result diff --git a/src/muse/readers/csv/assets.py b/src/muse/readers/csv/assets.py new file mode 100644 index 000000000..0fd4090dc --- /dev/null +++ b/src/muse/readers/csv/assets.py @@ -0,0 +1,63 @@ +"""Reads and processes existing capacity data from a CSV file. + +This runs once per subsector, reading in a csv file and outputting an xarray DataArray. +""" + +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import create_assets, create_multiindex, create_xarray_dataset, read_csv + + +def read_assets(path: Path) -> xr.DataArray: + """Reads and processes existing capacity data from a CSV file.""" + df = read_existing_capacity_csv(path) + return process_existing_capacity(df) + + +def read_existing_capacity_csv(path: Path) -> pd.DataFrame: + """Reads and formats data about initial capacity into a DataFrame.""" + required_columns = { + "region", + "technology", + } + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading initial capacity from {path}.", + ) + + +def process_existing_capacity(data: pd.DataFrame) -> xr.DataArray: + """Processes initial capacity DataFrame into an xarray DataArray.""" + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select year columns + year_columns = [col for col in data.columns if col.isdigit()] + + # Convert year columns to long format (i.e. 
single "year" column) + data = data.melt( + id_vars=["technology", "region"], + value_vars=year_columns, + var_name="year", + value_name="value", + ) + + # Create multiindex for region, technology, and year + data = create_multiindex( + data, + index_columns=["technology", "region", "year"], + index_names=["technology", "region", "year"], + drop_columns=True, + ) + + # Create Dataarray + result = create_xarray_dataset(data).value.astype(float) + + # Create assets + result = create_assets(result) + return result diff --git a/src/muse/readers/csv/commodities.py b/src/muse/readers/csv/commodities.py new file mode 100644 index 000000000..e07c4fc69 --- /dev/null +++ b/src/muse/readers/csv/commodities.py @@ -0,0 +1,68 @@ +"""Reads and processes global commodities data from a CSV file. + +This runs once per simulation, reading in a csv file and outputting an xarray Dataset. + +The CSV file will generally have the following columns: `commodity`, `commodity_type`, +unit (optional), description (optional). + +The resulting xarray Dataset will have a single `commodity` dimension, and variables +for `commodity_type` and `unit`. +""" + +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import camel_to_snake, create_xarray_dataset, standardize_dataframe + + +def read_commodities(path: Path) -> xr.Dataset: + """Reads and processes global commodities data from a CSV file.""" + df = read_global_commodities_csv(path) + return process_global_commodities(df) + + +def read_global_commodities_csv(path: Path) -> pd.DataFrame: + """Reads commodities information from input into a DataFrame.""" + # Due to legacy reasons, users can supply both Commodity and CommodityName columns + # In this case, we need to remove the Commodity column to avoid conflicts + # This is fine because Commodity just contains a long description that isn't needed + getLogger(__name__).info(f"Reading global commodities from {path}.") + df = pd.read_csv(path) + df = df.rename(columns=camel_to_snake) + if "commodity" in df.columns and "commodity_name" in df.columns: + df = df.drop(columns=["commodity"]) + + required_columns = { + "commodity", + "commodity_type", + } + data = standardize_dataframe( + df, + required_columns=required_columns, + ) + + # Raise warning if units are not defined + if "unit" not in data.columns: + msg = ( + "No units defined for commodities. Please define units for all commodities " + "in the global commodities file." + ) + getLogger(__name__).warning(msg) + + return data + + +def process_global_commodities(data: pd.DataFrame) -> xr.Dataset: + """Processes global commodities DataFrame into an xarray Dataset.""" + # Drop description column if present. It's useful to include in the file, but we + # don't need it for the simulation. + if "description" in data.columns: + data = data.drop(columns=["description"]) + + data.index = [u for u in data.commodity] + data = data.drop("commodity", axis=1) + data.index.name = "commodity" + return create_xarray_dataset(data) diff --git a/src/muse/readers/csv/correlation_consumption.py b/src/muse/readers/csv/correlation_consumption.py new file mode 100644 index 000000000..21d5f4106 --- /dev/null +++ b/src/muse/readers/csv/correlation_consumption.py @@ -0,0 +1,237 @@ +"""Reads and processes correlation consumption data from CSV files. + +This will only run for preset sectors that have macro drivers and regression files. +Otherwise, read_presets will be used instead. 
+""" + +from __future__ import annotations + +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_correlation_consumption( + macro_drivers_path: Path, + regression_path: Path, + timeslice_shares_path: Path | None = None, +) -> xr.Dataset: + """Read consumption data for a sector based on correlation files. + + This function calculates endogenous demand for a sector using macro drivers and + regression parameters. It applies optional filters, handles sector aggregation, + and distributes the consumption across timeslices if timeslice shares are provided. + + Args: + macro_drivers_path: Path to macro drivers file + regression_path: Path to regression parameters file + timeslice_shares_path: Path to timeslice shares file (optional) + + Returns: + xr.Dataset: Consumption data distributed across timeslices and regions + """ + from muse.regressions import endogenous_demand + from muse.timeslices import broadcast_timeslice, distribute_timeslice + + macro_drivers = read_macro_drivers(macro_drivers_path) + regression_parameters = read_regression_parameters(regression_path) + consumption = endogenous_demand( + drivers=macro_drivers, + regression_parameters=regression_parameters, + forecast=0, + ) + + # Legacy: we permit regression parameters to split by sector, so have to sum + if "sector" in consumption.dims: + consumption = consumption.sum("sector") + + # Split by timeslice + if timeslice_shares_path is not None: + shares = read_timeslice_shares(timeslice_shares_path) + consumption = broadcast_timeslice(consumption) * shares + else: + consumption = distribute_timeslice(consumption) + + return consumption + + +def read_timeslice_shares(path: Path) -> xr.DataArray: + """Reads and processes timeslice shares data from a CSV file.""" + df = read_timeslice_shares_csv(path) + return process_timeslice_shares(df) + + +def read_timeslice_shares_csv(path: Path) -> pd.DataFrame: + """Reads sliceshare information into a DataFrame.""" + data = read_csv( + path, + required_columns=["region", "timeslice"], + msg=f"Reading timeslice shares from {path}.", + ) + + return data + + +def process_timeslice_shares(data: pd.DataFrame) -> xr.DataArray: + """Processes timeslice shares DataFrame into an xarray DataArray.""" + from muse.commodities import COMMODITIES + from muse.timeslices import TIMESLICE + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["region", "timeslice"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Create multiindex for region and timeslice + data = create_multiindex( + data, + index_columns=["region", "timeslice", "commodity"], + index_names=["region", "timeslice", "commodity"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Assign timeslices + result = result.assign_coords(timeslice=TIMESLICE.timeslice) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result + + +def read_macro_drivers(path: Path) -> pd.DataFrame: + """Reads and processes macro drivers data from a CSV file.""" + df = read_macro_drivers_csv(path) + return process_macro_drivers(df) + + +def read_macro_drivers_csv(path: Path) -> pd.DataFrame: + """Reads a standard MUSE csv file for macro drivers into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "variable"], + msg=f"Reading macro drivers from {path}.", + ) + + # Validate required variables + required_variables = ["Population", "GDP|PPP"] + missing_variables = [ + var for var in required_variables if var not in table.variable.unique() + ] + if missing_variables: + raise ValueError(f"Missing required variables in {path}: {missing_variables}") + + return table + + +def process_macro_drivers(data: pd.DataFrame) -> xr.Dataset: + """Processes macro drivers DataFrame into an xarray Dataset.""" + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select year columns + year_columns = [col for col in data.columns if col.isdigit()] + + # Convert year columns to long format (i.e. single "year" column) + data = data.melt( + id_vars=["variable", "region"], + value_vars=year_columns, + var_name="year", + value_name="value", + ) + + # Pivot data to create Population and GDP|PPP columns + data = data.pivot( + index=["region", "year"], + columns="variable", + values="value", + ) + + # Legacy: rename Population to population and GDP|PPP to gdp + if "Population" in data.columns: + data = data.rename(columns={"Population": "population"}) + if "GDP|PPP" in data.columns: + data = data.rename(columns={"GDP|PPP": "gdp"}) + + # Create DataSet + result = create_xarray_dataset(data) + return result + + +def read_regression_parameters(path: Path) -> xr.Dataset: + """Reads and processes regression parameters from a CSV file.""" + df = read_regression_parameters_csv(path) + return process_regression_parameters(df) + + +def read_regression_parameters_csv(path: Path) -> pd.DataFrame: + """Reads the regression parameters from a MUSE csv file into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "function_type", "coeff"], + msg=f"Reading regression parameters from {path}.", + ) + + # Legacy: warn about "sector" column + if "sector" in table.columns: + getLogger(__name__).warning( + f"The sector column (in file {path}) is deprecated. Please remove." 
+ ) + + return table + + +def process_regression_parameters(data: pd.DataFrame) -> xr.Dataset: + """Processes regression parameters DataFrame into an xarray Dataset.""" + from muse.commodities import COMMODITIES + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Melt to long format + melted = data.melt( + id_vars=["sector", "region", "function_type", "coeff"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Extract sector -> function_type mapping + sector_to_ftype = melted.drop_duplicates(["sector", "function_type"])[ + ["sector", "function_type"] + ].set_index("sector")["function_type"] + + # Pivot to create coefficient variables + pivoted = melted.pivot_table( + index=["sector", "region", "commodity"], columns="coeff", values="value" + ) + + # Create dataset and add function_type + result = create_xarray_dataset(pivoted) + result["function_type"] = xr.DataArray( + sector_to_ftype[result.sector.values].astype(object), + dims=["sector"], + name="function_type", + ) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result diff --git a/src/muse/readers/csv/general.py b/src/muse/readers/csv/general.py new file mode 100644 index 000000000..f1bf99657 --- /dev/null +++ b/src/muse/readers/csv/general.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import create_xarray_dataset, read_csv + + +def read_attribute_table(path: Path) -> xr.Dataset: + """Reads and processes attribute table data from a CSV file.""" + df = read_attribute_table_csv(path) + return process_attribute_table(df) + + +def read_attribute_table_csv(path: Path) -> pd.DataFrame: + """Read a standard MUSE csv file for price projections into a DataFrame.""" + table = read_csv( + path, + required_columns=["region", "attribute", "year"], + msg=f"Reading attribute table from {path}.", + ) + return table + + +def process_attribute_table(data: pd.DataFrame) -> xr.Dataset: + """Process attribute table DataFrame into an xarray Dataset.""" + # Extract commodity columns + commodities = [ + col for col in data.columns if col not in ["region", "year", "attribute"] + ] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["region", "year", "attribute"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Pivot data over attributes + data = data.pivot( + index=["region", "year", "commodity"], + columns="attribute", + values="value", + ) + + # Create DataSet + result = create_xarray_dataset(data) + return result diff --git a/src/muse/readers/csv/helpers.py b/src/muse/readers/csv/helpers.py new file mode 100644 index 000000000..e74e4611f --- /dev/null +++ b/src/muse/readers/csv/helpers.py @@ -0,0 +1,346 @@ +"""Helper functions for reading and processing CSV files.""" + +from __future__ import annotations + +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from muse.utilities import camel_to_snake + +# Global mapping of column names to their standardized versions +# This is for backwards compatibility with old file formats +COLUMN_RENAMES = { + "process_name": "technology", + "process": "technology", + "sector_name": "sector", + "region_name": "region", + "time": "year", + "commodity_name": "commodity", + "comm_type": "commodity_type", + "commodity_price": "prices", + "units_commodity_price": "units_prices", + "enduse": "end_use", + "sn": "timeslice", + "commodity_emission_factor_CO2": "emmission_factor", + "utilisation_factor": "utilization_factor", + "objsort": "obj_sort", + "objsort1": "obj_sort1", + "objsort2": "obj_sort2", + "objsort3": "obj_sort3", + "time_slice": "timeslice", + "price": "prices", +} + +# Columns who's values should be converted from camelCase to snake_case +CAMEL_TO_SNAKE_COLUMNS = [ + "tech_type", + "commodity", + "commodity_type", + "agent_share", + "attribute", + "sector", + "region", + "parameter", +] + +# Global mapping of column names to their expected types +COLUMN_TYPES = { + "year": int, + "region": str, + "technology": str, + "commodity": str, + "sector": str, + "attribute": str, + "variable": str, + "timeslice": int, # For tables that require int timeslice instead of month etc. + "name": str, + "commodity_type": str, + "tech_type": str, + "type": str, + "function_type": str, + "level": str, + "search_rule": str, + "decision_method": str, + "quantity": float, + "share": float, + "coeff": str, + "value": float, + "utilization_factor": float, + "minimum_service_factor": float, + "maturity_threshold": float, + "spend_limit": float, + "prices": float, + "emmission_factor": float, +} + +DEFAULTS = { + "cap_par": 0, + "cap_exp": 1, + "fix_par": 0, + "fix_exp": 1, + "var_par": 0, + "var_exp": 1, + "interest_rate": 0, + "utilization_factor": 1, + "minimum_service_factor": 0, + "search_rule": "all", + "decision_method": "single", +} + + +def standardize_columns(data: pd.DataFrame) -> pd.DataFrame: + """Standardizes column names in a DataFrame. + + This function: + 1. Converts column names to snake_case + 2. Applies the global COLUMN_RENAMES mapping + 3. 
Preserves any columns not in the mapping + + Args: + data: DataFrame to standardize + + Returns: + DataFrame with standardized column names + """ + # Drop index column if present + if data.columns[0] == "" or data.columns[0].startswith("Unnamed"): + data = data.iloc[:, 1:] + + # Convert columns to snake_case + data = data.rename(columns=camel_to_snake) + + # Then apply global mapping + data = data.rename(columns=COLUMN_RENAMES) + + # Make sure there are no duplicate columns + if len(data.columns) != len(set(data.columns)): + raise ValueError(f"Duplicate columns in {data.columns}") + + return data + + +def create_multiindex( + data: pd.DataFrame, + index_columns: list[str], + index_names: list[str], + drop_columns: bool = True, +) -> pd.DataFrame: + """Creates a MultiIndex from specified columns. + + Args: + data: DataFrame to create index from + index_columns: List of column names to use for index + index_names: List of names for the index levels + drop_columns: Whether to drop the original columns + + Returns: + DataFrame with new MultiIndex + """ + index = pd.MultiIndex.from_arrays( + [data[col] for col in index_columns], names=index_names + ) + result = data.copy() + result.index = index + if drop_columns: + result = result.drop(columns=index_columns) + return result + + +def create_xarray_dataset( + data: pd.DataFrame, + disallow_nan: bool = True, +) -> xr.Dataset: + """Creates an xarray Dataset from a DataFrame with standardized options. + + Args: + data: DataFrame to convert + disallow_nan: Whether to raise an error if NaN values are found + + Returns: + xarray Dataset + """ + result = xr.Dataset.from_dataframe(data) + if disallow_nan: + nan_coords = get_nan_coordinates(result) + if nan_coords: + raise ValueError(f"Missing data for coordinates: {nan_coords}") + + if "year" in result.coords: + result = result.assign_coords(year=result.year.astype(int)) + result = result.sortby("year") + assert len(set(result.year.values)) == result.year.data.size # no duplicates + + return result + + +def get_nan_coordinates(dataset: xr.Dataset) -> list[tuple]: + """Get coordinates of a Dataset where any data variable has NaN values.""" + any_nan = sum(var.isnull() for var in dataset.data_vars.values()) + if any_nan.any(): + return any_nan.where(any_nan, drop=True).to_dataframe(name="").index.to_list() + return [] + + +def convert_column_types(data: pd.DataFrame) -> pd.DataFrame: + """Converts DataFrame columns to their expected types. + + Args: + data: DataFrame to convert + + Returns: + DataFrame with converted column types + """ + result = data.copy() + for column, expected_type in COLUMN_TYPES.items(): + if column in result.columns: + try: + if expected_type is int: + result[column] = pd.to_numeric(result[column], downcast="integer") + elif expected_type is float: + result[column] = pd.to_numeric(result[column]).astype(float) + elif expected_type is str: + result[column] = result[column].astype(str) + except (ValueError, TypeError) as e: + raise ValueError( + f"Could not convert column '{column}' to {expected_type.__name__}: {e}" # noqa: E501 + ) + return result + + +def standardize_dataframe( + data: pd.DataFrame, + required_columns: list[str] | None = None, + exclude_extra_columns: bool = False, +) -> pd.DataFrame: + """Standardizes a DataFrame to a common format. + + Args: + data: DataFrame to standardize + required_columns: List of column names that must be present (optional) + exclude_extra_columns: If True, exclude any columns not in required_columns list + (optional). 
This can be important if extra columns can mess up the resulting + xarray object. + + Returns: + DataFrame containing the standardized data + """ + if required_columns is None: + required_columns = [] + + # Standardize column names + data = standardize_columns(data) + + # Convert specified column values from camelCase to snake_case + for col in CAMEL_TO_SNAKE_COLUMNS: + if col in data.columns: + data[col] = data[col].apply(camel_to_snake) + + # Fill missing values with defaults + data = data.fillna(DEFAULTS) + for col, default in DEFAULTS.items(): + if col not in data.columns and col in required_columns: + data[col] = default + + # Check/convert data types + data = convert_column_types(data) + + # Validate required columns if provided + if required_columns: + missing_columns = [col for col in required_columns if col not in data.columns] + if missing_columns: + raise ValueError(f"Missing required columns: {missing_columns}") + + # Exclude extra columns if requested + if exclude_extra_columns: + data = data[list(required_columns)] + + return data + + +def read_csv( + path: Path, + float_precision: str = "high", + required_columns: list[str] | None = None, + exclude_extra_columns: bool = False, + msg: str | None = None, +) -> pd.DataFrame: + """Reads and standardizes a CSV file into a DataFrame. + + Args: + path: Path to the CSV file + float_precision: Precision to use when reading floats + required_columns: List of column names that must be present (optional) + exclude_extra_columns: If True, exclude any columns not in required_columns list + (optional). This can be important if extra columns can mess up the resulting + xarray object. + msg: Message to log (optional) + + Returns: + DataFrame containing the standardized data + """ + # Log message + if msg: + getLogger(__name__).info(msg) + + # Check if file exists + if not path.is_file(): + raise OSError(f"{path} does not exist.") + + # Check if there's a units row (in which case we need to skip it) + with open(path) as f: + next(f) # Skip header row + first_data_row = f.readline().strip() + skiprows = [1] if first_data_row.startswith("Unit") else None + + # Read the file + data = pd.read_csv( + path, + float_precision=float_precision, + low_memory=False, + skiprows=skiprows, + ) + + # Standardize the DataFrame + return standardize_dataframe( + data, + required_columns=required_columns, + exclude_extra_columns=exclude_extra_columns, + ) + + +def check_commodities( + data: xr.Dataset | xr.DataArray, fill_missing: bool = True, fill_value: float = 0 +) -> xr.Dataset | xr.DataArray: + """Validates and optionally fills missing commodities in data.""" + from muse.commodities import COMMODITIES + + # Make sure there are no commodities in data but not in global commodities + extra_commodities = [ + c for c in data.commodity.values if c not in COMMODITIES.commodity.values + ] + if extra_commodities: + raise ValueError( + "The following commodities were not found in global commodities file: " + f"{extra_commodities}" + ) + + # Add any missing commodities with fill_value + if fill_missing: + data = data.reindex( + commodity=COMMODITIES.commodity.values, fill_value=fill_value + ) + return data + + +def create_assets(data: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset: + """Creates assets from technology data.""" + # Rename technology to asset + result = data.drop_vars("technology").rename(technology="asset") + result["technology"] = "asset", data.technology.values + + # Add installed year + result["installed"] = ("asset", 
[int(result.year.min())] * len(result.technology)) + return result diff --git a/src/muse/readers/csv/market.py b/src/muse/readers/csv/market.py new file mode 100644 index 000000000..add03faee --- /dev/null +++ b/src/muse/readers/csv/market.py @@ -0,0 +1,141 @@ +"""Reads and processes initial market data. + +The data is shared between sectors, so we only do this once per simulation. + +This data is contained in three csv files: +- projections: contains price projections for commodities +- base year import (optional): contains imports for commodities +- base year export (optional): contains exports for commodities + +A single xarray Dataset is returned, with dimensions for `region`, `year`, `commodity`, +and `timeslice`, and variables for `prices`, `exports`, `imports`, and `static_trade`. +""" + +from __future__ import annotations + +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .general import process_attribute_table +from .helpers import check_commodities, read_csv + + +def read_initial_market( + projections_path: Path, + base_year_import_path: Path | None = None, + base_year_export_path: Path | None = None, + currency: str | None = None, +) -> xr.Dataset: + """Reads and processes initial market data. + + Args: + projections_path: path to the projections file + base_year_import_path: path to the base year import file (optional) + base_year_export_path: path to the base year export file (optional) + currency: currency string (e.g. "USD") + + Returns: + xr.Dataset: Dataset containing initial market data. + """ + # Read projections + projections_df = read_projections_csv(projections_path) + + # Read base year export (optional) + if base_year_export_path: + export_df = read_csv( + base_year_export_path, + msg=f"Reading base year export from {base_year_export_path}.", + ) + else: + export_df = None + + # Read base year import (optional) + if base_year_import_path: + import_df = read_csv( + base_year_import_path, + msg=f"Reading base year import from {base_year_import_path}.", + ) + else: + import_df = None + + # Assemble into xarray Dataset + result = process_initial_market(projections_df, import_df, export_df, currency) + return result + + +def read_projections_csv(path: Path) -> pd.DataFrame: + """Reads projections data from a CSV file.""" + required_columns = { + "region", + "attribute", + "year", + } + projections_df = read_csv( + path, required_columns=required_columns, msg=f"Reading projections from {path}." + ) + return projections_df + + +def process_initial_market( + projections_df: pd.DataFrame, + import_df: pd.DataFrame | None, + export_df: pd.DataFrame | None, + currency: str | None = None, +) -> xr.Dataset: + """Process market data DataFrames into an xarray Dataset. + + Args: + projections_df: DataFrame containing projections data + import_df: Optional DataFrame containing import data + export_df: Optional DataFrame containing export data + currency: Currency string (e.g. 
"USD") + """ + from muse.commodities import COMMODITIES + from muse.timeslices import broadcast_timeslice, distribute_timeslice + + # Process projections + projections = process_attribute_table(projections_df).commodity_price.astype( + "float64" + ) + + # Process optional trade data + if export_df is not None: + base_year_export = process_attribute_table(export_df).exports.astype("float64") + else: + base_year_export = xr.zeros_like(projections) + + if import_df is not None: + base_year_import = process_attribute_table(import_df).imports.astype("float64") + else: + base_year_import = xr.zeros_like(projections) + + # Distribute data over timeslices + projections = broadcast_timeslice(projections, level=None) + base_year_export = distribute_timeslice(base_year_export, level=None) + base_year_import = distribute_timeslice(base_year_import, level=None) + + # Assemble into xarray + result = xr.Dataset( + { + "prices": projections, + "exports": base_year_export, + "imports": base_year_import, + "static_trade": base_year_import - base_year_export, + } + ) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + + # Add units_prices coordinate + # Only added if the currency is specified and commodity units are defined + if currency and "unit" in COMMODITIES.data_vars: + units_prices = [ + f"{currency}/{COMMODITIES.sel(commodity=c).unit.item()}" + for c in result.commodity.values + ] + result = result.assign_coords(units_prices=("commodity", units_prices)) + + return result diff --git a/src/muse/readers/csv/presets.py b/src/muse/readers/csv/presets.py new file mode 100644 index 000000000..abb68a6ac --- /dev/null +++ b/src/muse/readers/csv/presets.py @@ -0,0 +1,114 @@ +"""Reads and processes preset data from multiple CSV files. + +This runs once per sector, reading in csv files and outputting an xarray Dataset. +""" + +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_presets(presets_paths: Path) -> xr.Dataset: + """Reads and processes preset data from multiple CSV files. + + Accepts a path pattern for presets files, e.g. `Path("path/to/*Consumption.csv")`. + The file name of each file must contain a year (e.g. "2020Consumption.csv"). + """ + from glob import glob + from re import match + + # Find all files matching the path pattern + allfiles = [Path(p) for p in glob(str(presets_paths))] + if len(allfiles) == 0: + raise OSError(f"No files found with paths {presets_paths}") + + # Read all files + datas: dict[int, pd.DataFrame] = {} + for path in allfiles: + # Extract year from filename + reyear = match(r"\S*.(\d{4})\S*\.csv", path.name) + if reyear is None: + raise OSError(f"Unexpected filename {path.name}") + year = int(reyear.group(1)) + if year in datas: + raise OSError(f"Year f{year} was found twice") + + # Read data + data = read_presets_csv(path) + data["year"] = year + datas[year] = data + + # Process data + datas = process_presets(datas) + return datas + + +def read_presets_csv(path: Path) -> pd.DataFrame: + data = read_csv( + path, + required_columns=["region", "timeslice"], + msg=f"Reading presets from {path}.", + ) + + # Legacy: drop technology column and sum data (PR #448) + if "technology" in data.columns: + getLogger(__name__).warning( + f"The technology (or ProcessName) column in file {path} is " + "deprecated. 
Data has been summed across technologies, and this column " + "has been dropped." + ) + data = ( + data.drop(columns=["technology"]) + .groupby(["region", "timeslice"]) + .sum() + .reset_index() + ) + + return data + + +def process_presets(datas: dict[int, pd.DataFrame]) -> xr.Dataset: + """Processes preset DataFrames into an xarray Dataset.""" + from muse.commodities import COMMODITIES + from muse.timeslices import TIMESLICE + + # Combine into a single DataFrame + data = pd.concat(datas.values()) + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. single "commodity" column) + data = data.melt( + id_vars=["region", "year", "timeslice"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Create multiindex for region, year, timeslice and commodity + data = create_multiindex( + data, + index_columns=["region", "year", "timeslice", "commodity"], + index_names=["region", "year", "timeslice", "commodity"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Assign timeslices + result = result.assign_coords(timeslice=TIMESLICE.timeslice) + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result diff --git a/src/muse/readers/csv/technologies.py b/src/muse/readers/csv/technologies.py new file mode 100644 index 000000000..a786a06d6 --- /dev/null +++ b/src/muse/readers/csv/technologies.py @@ -0,0 +1,381 @@ +"""Reads and processes technology data from multiple CSV files. + +This runs once per sector, reading in csv files and outputting an xarray Dataset. + +Several csv files are read in: +- technodictionary: contains technology parameters +- comm_out: contains output commodity data +- comm_in: contains input commodity data +- technodata_timeslices (optional): allows some parameters to be defined per timeslice +""" + +from __future__ import annotations + +from logging import getLogger +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + check_commodities, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_technologies( + technodata_path: Path, + comm_out_path: Path, + comm_in_path: Path, + time_framework: list[int], + interpolation_mode: str = "linear", + technodata_timeslices_path: Path | None = None, +) -> xr.Dataset: + """Reads and processes technology data from multiple CSV files. + + Will also interpolate data to the time framework if provided. + + Args: + technodata_path: path to the technodata file + comm_out_path: path to the comm_out file + comm_in_path: path to the comm_in file + time_framework: list of years to interpolate data to + interpolation_mode: Interpolation mode to use + technodata_timeslices_path: path to the technodata_timeslices file + + Returns: + xr.Dataset: Dataset containing the processed technology data. Any fields + that differ by year will contain a "year" dimension interpolated to the + time framework. Other fields will not have a "year" dimension. 
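A usage sketch mirroring the tests, assuming the default example model (note that
`read_settings` must run first so that the global commodity and timeslice definitions
are available):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from muse import examples
    from muse.readers.csv import read_technologies
    from muse.readers.toml import read_settings

    with TemporaryDirectory() as tmpdir:
        examples.copy_model(name="default", path=Path(tmpdir))
        model_path = Path(tmpdir) / "model"
        read_settings(model_path / "settings.toml")  # sets up global commodities/timeslices
        techs = read_technologies(
            technodata_path=model_path / "power" / "Technodata.csv",
            comm_out_path=model_path / "power" / "CommOut.csv",
            comm_in_path=model_path / "power" / "CommIn.csv",
            time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050],
            interpolation_mode="linear",
        )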
+ """ + # Read all data + technodata = read_technodictionary(technodata_path) + comm_out = read_io_technodata(comm_out_path) + comm_in = read_io_technodata(comm_in_path) + technodata_timeslices = ( + read_technodata_timeslices(technodata_timeslices_path) + if technodata_timeslices_path + else None + ) + + # Assemble xarray Dataset + return process_technologies( + technodata, + comm_out, + comm_in, + time_framework, + interpolation_mode, + technodata_timeslices, + ) + + +def read_technodictionary(path: Path) -> xr.Dataset: + """Reads and processes technodictionary data from a CSV file.""" + df = read_technodictionary_csv(path) + return process_technodictionary(df) + + +def read_technodictionary_csv(path: Path) -> pd.DataFrame: + """Reads and formats technodata into a DataFrame.""" + required_columns = { + "cap_exp", + "region", + "var_par", + "fix_exp", + "interest_rate", + "utilization_factor", + "minimum_service_factor", + "year", + "cap_par", + "var_exp", + "technology", + "technical_life", + "fix_par", + } + data = read_csv( + path, + required_columns=required_columns, + msg=f"Reading technodictionary from {path}.", + ) + + # Check for deprecated columns + if "fuel" in data.columns: + msg = ( + f"The 'fuel' column in {path} has been deprecated. " + "This information is now determined from CommIn files. " + "Please remove this column from your Technodata files." + ) + getLogger(__name__).warning(msg) + if "end_use" in data.columns: + msg = ( + f"The 'end_use' column in {path} has been deprecated. " + "This information is now determined from CommOut files. " + "Please remove this column from your Technodata files." + ) + getLogger(__name__).warning(msg) + if "scaling_size" in data.columns: + msg = ( + f"The 'scaling_size' column in {path} has been deprecated. " + "Please remove this column from your Technodata files." 
+ ) + getLogger(__name__).warning(msg) + + return data + + +def process_technodictionary(data: pd.DataFrame) -> xr.Dataset: + """Processes technodictionary DataFrame into an xarray Dataset.""" + # Create multiindex for technology and region + data = create_multiindex( + data, + index_columns=["technology", "region", "year"], + index_names=["technology", "region", "year"], + drop_columns=True, + ) + + # Create dataset + result = create_xarray_dataset(data) + + # Handle tech_type if present + if "type" in result.variables: + result["tech_type"] = result.type.isel(region=0, year=0) + + return result + + +def read_technodata_timeslices(path: Path) -> xr.Dataset: + """Reads and processes technodata timeslices from a CSV file.""" + df = read_technodata_timeslices_csv(path) + return process_technodata_timeslices(df) + + +def read_technodata_timeslices_csv(path: Path) -> pd.DataFrame: + """Reads and formats technodata timeslices into a DataFrame.""" + from muse.timeslices import TIMESLICE + + timeslice_columns = set(TIMESLICE.coords["timeslice"].indexes["timeslice"].names) + required_columns = { + "utilization_factor", + "technology", + "minimum_service_factor", + "region", + "year", + } | timeslice_columns + return read_csv( + path, + required_columns=required_columns, + exclude_extra_columns=True, + msg=f"Reading technodata timeslices from {path}.", + ) + + +def process_technodata_timeslices(data: pd.DataFrame) -> xr.Dataset: + """Processes technodata timeslices DataFrame into an xarray Dataset.""" + from muse.timeslices import TIMESLICE, sort_timeslices + + # Create multiindex for all columns except factor columns + factor_columns = ["utilization_factor", "minimum_service_factor", "obj_sort"] + index_columns = [col for col in data.columns if col not in factor_columns] + data = create_multiindex( + data, + index_columns=index_columns, + index_names=index_columns, + drop_columns=True, + ) + + # Create dataset + result = create_xarray_dataset(data) + + # Stack timeslice levels (month, day, hour) into a single timeslice dimension + timeslice_levels = TIMESLICE.coords["timeslice"].indexes["timeslice"].names + if all(level in result.dims for level in timeslice_levels): + result = result.stack(timeslice=timeslice_levels) + return sort_timeslices(result) + + +def read_io_technodata(path: Path) -> xr.Dataset: + """Reads and processes input/output technodata from a CSV file.""" + df = read_io_technodata_csv(path) + return process_io_technodata(df) + + +def read_io_technodata_csv(path: Path) -> pd.DataFrame: + """Reads process inputs or outputs into a DataFrame.""" + data = read_csv( + path, + required_columns=["technology", "region", "year"], + msg=f"Reading IO technodata from {path}.", + ) + + # Unspecified Level values default to "fixed" + if "level" in data.columns: + data["level"] = data["level"].fillna("fixed") + else: + # Particularly relevant to outputs files where the Level column is omitted by + # default, as only "fixed" outputs are allowed. + data["level"] = "fixed" + + return data + + +def process_io_technodata(data: pd.DataFrame) -> xr.Dataset: + """Processes IO technodata DataFrame into an xarray Dataset.""" + from muse.commodities import COMMODITIES + + # Extract commodity columns + commodities = [c for c in data.columns if c in COMMODITIES.commodity.values] + + # Convert commodity columns to long format (i.e. 
single "commodity" column) + data = data.melt( + id_vars=["technology", "region", "year", "level"], + value_vars=commodities, + var_name="commodity", + value_name="value", + ) + + # Pivot data to create fixed and flexible columns + data = data.pivot( + index=["technology", "region", "year", "commodity"], + columns="level", + values="value", + ) + + # Create xarray dataset + result = create_xarray_dataset(data) + + # Fill in flexible data + if "flexible" in result.data_vars: + result["flexible"] = result.flexible.fillna(0) + else: + result["flexible"] = xr.zeros_like(result.fixed).rename("flexible") + + # Check commodities + result = check_commodities(result, fill_missing=True, fill_value=0) + return result + + +def process_technologies( + technodata: xr.Dataset, + comm_out: xr.Dataset, + comm_in: xr.Dataset, + time_framework: list[int], + interpolation_mode: str = "linear", + technodata_timeslices: xr.Dataset | None = None, +) -> xr.Dataset: + """Processes technology data DataFrames into an xarray Dataset.""" + from muse.commodities import COMMODITIES, CommodityUsage + from muse.timeslices import drop_timeslice + from muse.utilities import interpolate_technodata + + # Process inputs/outputs + ins = comm_in.rename(flexible="flexible_inputs", fixed="fixed_inputs") + outs = comm_out.rename(flexible="flexible_outputs", fixed="fixed_outputs") + + # Legacy: Remove flexible outputs + if not (outs["flexible_outputs"] == 0).all(): + raise ValueError( + "'flexible' outputs are not permitted. All outputs must be 'fixed'" + ) + outs = outs.drop_vars("flexible_outputs") + + # Collect all years from the time framework and data files + time_framework = list( + set(time_framework).union( + technodata.year.values.tolist(), + ins.year.values.tolist(), + outs.year.values.tolist(), + technodata_timeslices.year.values.tolist() if technodata_timeslices else [], + ) + ) + + # Interpolate data to match the time framework + technodata = interpolate_technodata(technodata, time_framework, interpolation_mode) + outs = interpolate_technodata(outs, time_framework, interpolation_mode) + ins = interpolate_technodata(ins, time_framework, interpolation_mode) + if technodata_timeslices: + technodata_timeslices = interpolate_technodata( + technodata_timeslices, time_framework, interpolation_mode + ) + + # Merge inputs/outputs with technodata + technodata = technodata.merge(outs).merge(ins) + + # Merge technodata_timeslices if provided. This will prioritise values defined in + # technodata_timeslices, and fallback to the non-timesliced technodata for any + # values that are not defined in technodata_timeslices. 
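+    # For example, if TechnodataTimeslices.csv defines utilization_factor for only a
+    # subset of technologies or years, combine_first keeps the timesliced values
+    # where they exist and fills the remaining entries from the plain Technodata.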
+ if technodata_timeslices: + technodata["utilization_factor"] = ( + technodata_timeslices.utilization_factor.combine_first( + technodata.utilization_factor + ) + ) + technodata["minimum_service_factor"] = drop_timeslice( + technodata_timeslices.minimum_service_factor.combine_first( + technodata.minimum_service_factor + ) + ) + + # Check commodities + technodata = check_commodities(technodata, fill_missing=False) + + # Add info about commodities + technodata = technodata.merge(COMMODITIES.sel(commodity=technodata.commodity)) + + # Add commodity usage flags + technodata["comm_usage"] = ( + "commodity", + CommodityUsage.from_technologies(technodata).values, + ) + technodata = technodata.drop_vars("commodity_type") + + # Check utilization and minimum service factors + check_utilization_and_minimum_service_factors(technodata) + + return technodata + + +def check_utilization_and_minimum_service_factors(data: xr.Dataset) -> None: + """Check utilization and minimum service factors in an xarray dataset. + + Args: + data: xarray Dataset containing utilization_factor and minimum_service_factor + """ + if "utilization_factor" not in data.data_vars: + raise ValueError( + "A technology needs to have a utilization factor defined for every " + "timeslice." + ) + + # Check UF not all zero (sum across timeslice dimension if it exists) + if "timeslice" in data.dims: + utilization_sum = data.utilization_factor.sum(dim="timeslice") + else: + utilization_sum = data.utilization_factor + + if (utilization_sum == 0).any(): + raise ValueError( + "A technology can not have a utilization factor of 0 for every timeslice." + ) + + # Check UF in range + utilization = data.utilization_factor + if not ((utilization >= 0) & (utilization <= 1)).all(): + raise ValueError( + "Utilization factor values must all be between 0 and 1 inclusive." + ) + + # Check MSF in range + min_service_factor = data.minimum_service_factor + if not ((min_service_factor >= 0) & (min_service_factor <= 1)).all(): + raise ValueError( + "Minimum service factor values must all be between 0 and 1 inclusive." + ) + + # Check UF not below MSF + if (data.utilization_factor < data.minimum_service_factor).any(): + raise ValueError( + "Utilization factors must all be greater than or equal " + "to their corresponding minimum service factors." + ) diff --git a/src/muse/readers/csv/trade_assets.py b/src/muse/readers/csv/trade_assets.py new file mode 100644 index 000000000..652f1ad3b --- /dev/null +++ b/src/muse/readers/csv/trade_assets.py @@ -0,0 +1,66 @@ +"""Reads and processes existing trade data from a CSV file. + +We only use this for trade sectors, otherwise we use read_assets instead. 
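The dispatch between the two readers mirrors `Subsector.factory`: it peeks at the CSV
and uses this reader only when a `year` column is present. A sketch (`path` is a
placeholder for a sector's ExistingCapacity.csv or ExistingTrade.csv file):

    from muse.readers import read_assets, read_csv, read_trade_assets

    df = read_csv(path)
    if "year" in df.columns:
        capacity = read_trade_assets(path)  # trade sectors (ExistingTrade files)
    else:
        capacity = read_assets(path)  # non-trade sectors (ExistingCapacity files)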
+""" + +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + create_assets, + create_multiindex, + create_xarray_dataset, + read_csv, +) + + +def read_trade_assets(path: Path) -> xr.DataArray: + """Reads and processes existing trade data from a CSV file.""" + df = read_existing_trade_csv(path) + return process_existing_trade(df) + + +def read_existing_trade_csv(path: Path) -> pd.DataFrame: + required_columns = { + "region", + "technology", + "year", + } + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading existing trade from {path}.", + ) + + +def process_existing_trade(data: pd.DataFrame) -> xr.DataArray: + # Select region columns + # TODO: this is a bit unsafe as user could supply other columns + regions = [ + col for col in data.columns if col not in ["technology", "region", "year"] + ] + + # Melt data over regions + data = data.melt( + id_vars=["technology", "region", "year"], + value_vars=regions, + var_name="dst_region", + value_name="value", + ) + + # Create multiindex for region, dst_region, technology and year + data = create_multiindex( + data, + index_columns=["region", "dst_region", "technology", "year"], + index_names=["region", "dst_region", "technology", "year"], + drop_columns=True, + ) + + # Create DataArray + result = create_xarray_dataset(data).value.astype(float) + + # Create assets from technologies + result = create_assets(result) + return result diff --git a/src/muse/readers/csv/trade_technodata.py b/src/muse/readers/csv/trade_technodata.py new file mode 100644 index 000000000..69240e952 --- /dev/null +++ b/src/muse/readers/csv/trade_technodata.py @@ -0,0 +1,60 @@ +"""Reads and processes trade technodata from a CSV file. + +We only use this for trade sectors. In this case, it gets added on the the dataset +created by `read_technologies`. 
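A sketch of the expected shape (`sector_path` and the file name below are illustrative
placeholders, not prescribed by this module):

    from muse.readers.csv.trade_technodata import read_trade_technodata

    trade = read_trade_technodata(sector_path / "TradeTechnodata.csv")
    # -> xr.Dataset indexed by technology, region and dst_region, with one data
    #    variable per distinct value of the original `parameter` column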
+""" + +from pathlib import Path + +import pandas as pd +import xarray as xr + +from .helpers import ( + create_xarray_dataset, + read_csv, +) + + +def read_trade_technodata(path: Path) -> xr.Dataset: + """Reads and processes trade technodata from a CSV file.""" + df = read_trade_technodata_csv(path) + return process_trade_technodata(df) + + +def read_trade_technodata_csv(path: Path) -> pd.DataFrame: + required_columns = {"technology", "region", "parameter"} + return read_csv( + path, + required_columns=required_columns, + msg=f"Reading trade technodata from {path}.", + ) + + +def process_trade_technodata(data: pd.DataFrame) -> xr.Dataset: + # Drop unit column if present + if "unit" in data.columns: + data = data.drop(columns=["unit"]) + + # Select region columns + # TODO: this is a bit unsafe as user could supply other columns + regions = [ + col for col in data.columns if col not in ["technology", "region", "parameter"] + ] + + # Melt data over regions + data = data.melt( + id_vars=["technology", "region", "parameter"], + value_vars=regions, + var_name="dst_region", + value_name="value", + ) + + # Pivot data over parameters + data = data.pivot( + index=["technology", "region", "dst_region"], + columns="parameter", + values="value", + ) + + # Create DataSet + return create_xarray_dataset(data) diff --git a/src/muse/readers/toml.py b/src/muse/readers/toml.py index 5c49320d5..06020ebd0 100644 --- a/src/muse/readers/toml.py +++ b/src/muse/readers/toml.py @@ -556,7 +556,7 @@ def check_subsector_settings(settings: dict) -> None: getLogger(__name__).warning(msg) -def read_technodata( +def read_sector_technodata( settings: Any, sector_name: str, interpolation_mode: str = "linear", @@ -646,7 +646,11 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: xr.Dataset: Dataset containing consumption and supply data for the sector. Costs are initialized to zero. """ - from muse.readers import read_attribute_table, read_presets + from muse.readers import ( + read_attribute_table, + read_correlation_consumption, + read_presets, + ) from muse.timeslices import distribute_timeslice, drop_timeslice sector_conf = getattr(settings.sectors, sector_name) @@ -662,7 +666,11 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: getattr(sector_conf, "macrodrivers_path", None) is not None and getattr(sector_conf, "regression_path", None) is not None ): - consumption = read_correlation_consumption(sector_conf) + consumption = read_correlation_consumption( + macro_drivers_path=sector_conf.macrodrivers_path, + regression_path=sector_conf.regression_path, + timeslice_shares_path=getattr(sector_conf, "timeslice_shares_path", None), + ) else: raise MissingSettings(f"Missing consumption data for sector {sector_name}") @@ -678,51 +686,3 @@ def read_presets_sector(settings: Any, sector_name: str) -> xr.Dataset: ) return presets - - -def read_correlation_consumption(sector_conf: Any) -> xr.Dataset: - """Read consumption data for a sector based on correlation files. - - This function calculates endogenous demand for a sector using macro drivers and - regression parameters. It applies optional filters, handles sector aggregation, - and distributes the consumption across timeslices if timeslice shares are provided. 
- - Args: - sector_conf: Sector configuration object containing paths to macro drivers, - regression parameters, and timeslice shares files - - Returns: - xr.Dataset: Consumption data distributed across timeslices and regions - """ - from muse.readers import ( - read_macro_drivers, - read_regression_parameters, - read_timeslice_shares, - ) - from muse.regressions import endogenous_demand - from muse.timeslices import broadcast_timeslice, distribute_timeslice - - macro_drivers = read_macro_drivers(sector_conf.macrodrivers_path) - regression_parameters = read_regression_parameters(sector_conf.regression_path) - consumption = endogenous_demand( - drivers=macro_drivers, - regression_parameters=regression_parameters, - forecast=0, - ) - - # Legacy: apply filters - if hasattr(sector_conf, "filters"): - consumption = consumption.sel(sector_conf.filters._asdict()) - - # Legacy: we permit regression parameters to split by sector, so have to sum - if "sector" in consumption.dims: - consumption = consumption.sum("sector") - - # Split by timeslice - if sector_conf.timeslice_shares_path is not None: - shares = read_timeslice_shares(sector_conf.timeslice_shares_path) - consumption = broadcast_timeslice(consumption) * shares - else: - consumption = distribute_timeslice(consumption) - - return consumption diff --git a/src/muse/regressions.py b/src/muse/regressions.py index 39560abe3..26dc2cd2e 100644 --- a/src/muse/regressions.py +++ b/src/muse/regressions.py @@ -48,7 +48,7 @@ class Regression(Callable): `xarray.Dataset` as input. In any case, it is given the gpd and population. These can be read from standard MUSE csv files: - >>> from muse.readers import read_macro_drivers + >>> from muse.readers.csv.correlation_consumption import read_macro_drivers >>> from muse.defaults import DATA_DIRECTORY >>> path_to_macrodrivers = DATA_DIRECTORY / "Macrodrivers.csv" >>> if path_to_macrodrivers.exists(): @@ -117,20 +117,15 @@ def _split_kwargs(data: Dataset, **kwargs) -> tuple[Mapping, Mapping]: @classmethod def factory( cls, - regression_data: str | Path | Dataset, + regression_data: Dataset, interpolation: str = "linear", base_year: int = 2010, **filters, ) -> Regression: """Creates a regression function from standard muse input.""" - from muse.readers import read_regression_parameters - assert cls.__mappings__ assert cls.__regression__ != "" - if isinstance(regression_data, (str, Path)): - regression_data = read_regression_parameters(regression_data) - # Get the parameters of interest with a 'simple' name coeffs = Dataset({k: regression_data[v] for k, v in cls.__mappings__.items()}) filters.update(coeffs.data_vars) @@ -138,15 +133,10 @@ def factory( def factory( - regression_parameters: str | Path | Dataset, + regression_parameters: Dataset, sector: str | Sequence[str] | None = None, ) -> Regression: """Creates regression functor from standard MUSE data for given sector.""" - from muse.readers import read_regression_parameters - - if isinstance(regression_parameters, (str, Path)): - regression_parameters = read_regression_parameters(regression_parameters) - if sector is not None: regression_parameters = regression_parameters.sel(sector=sector) @@ -203,7 +193,6 @@ def register_regression( registered regression functor. 
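The regression utilities now operate on already-loaded Datasets; callers read the CSV
data first, as `read_correlation_consumption` does. A sketch (paths are placeholders,
and the global commodities table must already be initialised via `read_settings`):

    from muse.readers.csv.correlation_consumption import (
        read_macro_drivers,
        read_regression_parameters,
    )
    from muse.regressions import endogenous_demand

    drivers = read_macro_drivers(macro_drivers_path)
    params = read_regression_parameters(regression_path)
    consumption = endogenous_demand(
        regression_parameters=params, drivers=drivers, forecast=0
    )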
""" from logging import getLogger - from pathlib import Path from muse.registration import name_variations @@ -504,15 +493,11 @@ def __call__( def endogenous_demand( - regression_parameters: str | Path | Dataset, - drivers: str | Path | Dataset, + regression_parameters: Dataset, + drivers: Dataset, sector: str | Sequence | None = None, **kwargs, ) -> Dataset: """Endogenous demand based on macro drivers and regression parameters.""" - from muse.readers import read_macro_drivers - regression = factory(regression_parameters, sector=sector) - if isinstance(drivers, (str, Path)): - drivers = read_macro_drivers(drivers) return regression(drivers, **kwargs) diff --git a/src/muse/sectors/sector.py b/src/muse/sectors/sector.py index 6d7efa077..b12f6e912 100644 --- a/src/muse/sectors/sector.py +++ b/src/muse/sectors/sector.py @@ -11,7 +11,7 @@ from muse.agents import AbstractAgent from muse.production import PRODUCTION_SIGNATURE -from muse.readers.toml import read_technodata +from muse.readers.toml import read_sector_technodata from muse.sectors.abstract import AbstractSector from muse.sectors.register import register_sector from muse.sectors.subsector import Subsector @@ -44,7 +44,7 @@ def factory(cls, name: str, settings: Any) -> Sector: interactions_config = sector_settings.get("interactions", None) # Read technologies - technologies = read_technodata( + technologies = read_sector_technodata( settings, name, interpolation_mode=interpolation_mode, diff --git a/src/muse/sectors/subsector.py b/src/muse/sectors/subsector.py index b7903178c..14f2698aa 100644 --- a/src/muse/sectors/subsector.py +++ b/src/muse/sectors/subsector.py @@ -118,7 +118,7 @@ def factory( from muse import investments as iv from muse.agents import InvestingAgent, agents_factory from muse.commodities import is_enduse - from muse.readers import read_csv, read_existing_trade, read_initial_capacity + from muse.readers import read_assets, read_csv, read_trade_assets # Read existing capacity or existing trade file # Have to peek at the file to determine what format the data is in @@ -126,9 +126,9 @@ def factory( # the parameter name in the settings file df = read_csv(settings.existing_capacity) if "year" not in df.columns: - existing_capacity = read_initial_capacity(settings.existing_capacity) + existing_capacity = read_assets(settings.existing_capacity) else: - existing_capacity = read_existing_trade(settings.existing_capacity) + existing_capacity = read_trade_assets(settings.existing_capacity) # Create agents agents = agents_factory( diff --git a/tests/conftest.py b/tests/conftest.py index 3cebe1449..379640a1d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,10 @@ from pytest import fixture from xarray import DataArray, Dataset +from muse import examples from muse.__main__ import patched_broadcast_compat_data from muse.agents import Agent +from muse.readers.toml import read_settings @contextmanager @@ -598,3 +600,48 @@ def rng(request): from numpy.random import default_rng return default_rng(getattr(request.config.option, "randomly_seed", None)) + + +@fixture +def default_model_path(tmp_path): + """Creates temporary folder containing the default model.""" + examples.copy_model(name="default", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_timeslice_model_path(tmp_path): + """Creates temporary folder containing the default model.""" + examples.copy_model(name="default_timeslice", path=tmp_path) + path = tmp_path / "model" + 
read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_retro_model_path(tmp_path): + """Creates temporary folder containing the default_retro model.""" + examples.copy_model(name="default_retro", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def trade_model_path(tmp_path): + """Creates temporary folder containing the trade model.""" + examples.copy_model(name="trade", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path + + +@fixture +def default_correlation_model_path(tmp_path): + """Creates temporary folder containing the correlation model.""" + examples.copy_model(name="default_correlation", path=tmp_path) + path = tmp_path / "model" + read_settings(path / "settings.toml") # setup globals + return path diff --git a/tests/test_csv_readers.py b/tests/test_csv_readers.py index d02c68125..8eb3de513 100644 --- a/tests/test_csv_readers.py +++ b/tests/test_csv_readers.py @@ -4,9 +4,7 @@ import numpy as np import xarray as xr -from pytest import fixture -from muse import examples from muse.readers.toml import read_settings # Common test data @@ -127,65 +125,11 @@ def assert_single_coordinate(data, selection, expected): ) -@fixture -def timeslice(): - """Sets up global timeslicing scheme to match the default model.""" - from muse.timeslices import setup_module - - timeslice = """ - [timeslices] - all-year.all-week.night = 1460 - all-year.all-week.morning = 1460 - all-year.all-week.afternoon = 1460 - all-year.all-week.early-peak = 1460 - all-year.all-week.late-peak = 1460 - all-year.all-week.evening = 1460 - """ - - setup_module(timeslice) - - -@fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def timeslice_model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default_timeslice", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path +def test_read_global_commodities(default_model_path): + from muse.readers.csv import read_commodities - -@fixture -def trade_model_path(tmp_path): - """Creates temporary folder containing the trade model.""" - examples.copy_model(name="trade", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -@fixture -def correlation_model_path(tmp_path): - """Creates temporary folder containing the correlation model.""" - examples.copy_model(name="default_correlation", path=tmp_path) - path = tmp_path / "model" - read_settings(path / "settings.toml") # setup globals - return path - - -def test_read_global_commodities(model_path): - from muse.readers.csv import read_global_commodities - - path = model_path / "GlobalCommodities.csv" - data = read_global_commodities(path) + path = default_model_path / "GlobalCommodities.csv" + data = read_commodities(path) # Check data against schema expected_schema = DatasetSchema( @@ -210,10 +154,10 @@ def test_read_global_commodities(model_path): assert_single_coordinate(data, coord, expected) -def test_read_presets(model_path): +def test_read_presets(default_model_path): from muse.readers.csv import read_presets - data = read_presets(str(model_path / "residential_presets" / 
"*.csv")) + data = read_presets(str(default_model_path / "residential_presets" / "*.csv")) # Check data against schema expected_schema = DataArraySchema( @@ -255,10 +199,12 @@ def test_read_presets(model_path): ) -def test_read_initial_market(model_path): +def test_read_initial_market(default_model_path): from muse.readers.csv import read_initial_market - data = read_initial_market(model_path / "Projections.csv", currency="MUS$2010") + data = read_initial_market( + default_model_path / "Projections.csv", currency="MUS$2010" + ) # Check data against schema expected_schema = DatasetSchema( @@ -318,10 +264,10 @@ def test_read_initial_market(model_path): assert_single_coordinate(data, coord, expected) -def test_read_technodictionary(model_path): - from muse.readers.csv import read_technodictionary +def test_read_technodictionary(default_model_path): + from muse.readers.csv.technologies import read_technodictionary - data = read_technodictionary(model_path / "power" / "Technodata.csv") + data = read_technodictionary(default_model_path / "power" / "Technodata.csv") # Check data against schema expected_schema = DatasetSchema( @@ -388,11 +334,11 @@ def test_read_technodictionary(model_path): assert_single_coordinate(data, coord, expected) -def test_read_technodata_timeslices(timeslice_model_path): - from muse.readers.csv import read_technodata_timeslices +def test_read_technodata_timeslices(default_timeslice_model_path): + from muse.readers.csv.technologies import read_technodata_timeslices data = read_technodata_timeslices( - timeslice_model_path / "power" / "TechnodataTimeslices.csv" + default_timeslice_model_path / "power" / "TechnodataTimeslices.csv" ) # Check data against schema @@ -438,10 +384,10 @@ def test_read_technodata_timeslices(timeslice_model_path): assert_single_coordinate(data, coord, expected) -def test_read_io_technodata(model_path): - from muse.readers.csv import read_io_technodata +def test_read_io_technodata(default_model_path): + from muse.readers.csv.technologies import read_io_technodata - data = read_io_technodata(model_path / "power" / "CommIn.csv") + data = read_io_technodata(default_model_path / "power" / "CommIn.csv") # Check data against schema expected_schema = DatasetSchema( @@ -476,10 +422,10 @@ def test_read_io_technodata(model_path): assert_single_coordinate(data, coord, expected) -def test_read_initial_capacity(model_path): - from muse.readers.csv import read_initial_capacity +def test_read_existing_capacity(default_model_path): + from muse.readers.csv import read_assets - data = read_initial_capacity(model_path / "power" / "ExistingCapacity.csv") + data = read_assets(default_model_path / "power" / "ExistingCapacity.csv") # Check data against schema expected_schema = DataArraySchema( @@ -512,10 +458,10 @@ def test_read_initial_capacity(model_path): assert data.sel(region="r1", asset=0, year=2020).item() == 1 -def test_read_agent_parameters(model_path): - from muse.readers.csv import read_agent_parameters +def test_read_agents(default_model_path): + from muse.readers.csv import read_agents - data = read_agent_parameters(model_path / "Agents.csv") + data = read_agents(default_model_path / "Agents.csv") assert isinstance(data, list) assert len(data) == 1 @@ -536,9 +482,9 @@ def test_read_agent_parameters(model_path): def test_read_existing_trade(trade_model_path): - from muse.readers.csv import read_existing_trade + from muse.readers.csv import read_trade_assets - data = read_existing_trade(trade_model_path / "gas" / "ExistingTrade.csv") + data = 
read_trade_assets(trade_model_path / "gas" / "ExistingTrade.csv") # Check data against schema expected_schema = DataArraySchema( @@ -616,11 +562,13 @@ def test_read_trade_technodata(trade_model_path): assert_single_coordinate(data, coord, expected) -def test_read_timeslice_shares(correlation_model_path): - from muse.readers.csv import read_timeslice_shares +def test_read_timeslice_shares(default_correlation_model_path): + from muse.readers.csv.correlation_consumption import read_timeslice_shares data = read_timeslice_shares( - correlation_model_path / "residential_presets" / "TimesliceSharepreset.csv" + default_correlation_model_path + / "residential_presets" + / "TimesliceSharepreset.csv" ) # Check data against schema @@ -658,11 +606,11 @@ def test_read_timeslice_shares(correlation_model_path): assert data.sel(**coord).item() == 0.071 -def test_read_macro_drivers(correlation_model_path): - from muse.readers.csv import read_macro_drivers +def test_read_macro_drivers(default_correlation_model_path): + from muse.readers.csv.correlation_consumption import read_macro_drivers data = read_macro_drivers( - correlation_model_path / "residential_presets" / "Macrodrivers.csv" + default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" ) # Check data against schema @@ -697,11 +645,13 @@ def test_read_macro_drivers(correlation_model_path): assert_single_coordinate(data, coord, expected) -def test_read_regression_parameters(correlation_model_path): - from muse.readers.csv import read_regression_parameters +def test_read_regression_parameters(default_correlation_model_path): + from muse.readers.csv.correlation_consumption import read_regression_parameters data = read_regression_parameters( - correlation_model_path / "residential_presets" / "regressionparameters.csv" + default_correlation_model_path + / "residential_presets" + / "regressionparameters.csv" ) # Check data against schema @@ -747,14 +697,14 @@ def test_read_regression_parameters(correlation_model_path): assert_single_coordinate(data, coord, expected) -def test_read_technologies(model_path): +def test_read_technologies(default_model_path): from muse.readers.csv import read_technologies # Read technologies data = read_technologies( - technodata_path=model_path / "power" / "Technodata.csv", - comm_out_path=model_path / "power" / "CommOut.csv", - comm_in_path=model_path / "power" / "CommIn.csv", + technodata_path=default_model_path / "power" / "Technodata.csv", + comm_out_path=default_model_path / "power" / "CommOut.csv", + comm_in_path=default_model_path / "power" / "CommIn.csv", time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050], interpolation_mode="linear", ) @@ -803,15 +753,15 @@ def test_read_technologies(model_path): ) -def test_read_technologies__timeslice(timeslice_model_path): +def test_read_technologies__timeslice(default_timeslice_model_path): """Testing the read_technologies function with the timeslice model.""" from muse.readers.csv import read_technologies data = read_technologies( - technodata_path=timeslice_model_path / "power" / "Technodata.csv", - comm_out_path=timeslice_model_path / "power" / "CommOut.csv", - comm_in_path=timeslice_model_path / "power" / "CommIn.csv", - technodata_timeslices_path=timeslice_model_path + technodata_path=default_timeslice_model_path / "power" / "Technodata.csv", + comm_out_path=default_timeslice_model_path / "power" / "CommOut.csv", + comm_in_path=default_timeslice_model_path / "power" / "CommIn.csv", + technodata_timeslices_path=default_timeslice_model_path / "power" / 
"TechnodataTimeslices.csv", time_framework=[2020, 2025, 2030, 2035, 2040, 2045, 2050], @@ -868,11 +818,11 @@ def test_read_technologies__timeslice(timeslice_model_path): ) -def test_read_technodata(model_path): - from muse.readers.toml import read_settings, read_technodata +def test_read_sector_technodata(default_model_path): + from muse.readers.toml import read_sector_technodata, read_settings - settings = read_settings(model_path / "settings.toml") - data = read_technodata( + settings = read_settings(default_model_path / "settings.toml") + data = read_sector_technodata( settings, sector_name="power", interpolation_mode="linear", @@ -923,12 +873,12 @@ def test_read_technodata(model_path): ) -def test_read_technodata__trade(trade_model_path): +def test_read_sector_technodata__trade(trade_model_path): """Testing the read_technodata function with the trade model.""" - from muse.readers.toml import read_settings, read_technodata + from muse.readers.toml import read_sector_technodata, read_settings settings = read_settings(trade_model_path / "settings.toml") - data = read_technodata( + data = read_sector_technodata( settings, sector_name="power", interpolation_mode="linear", @@ -981,10 +931,10 @@ def test_read_technodata__trade(trade_model_path): ) -def test_read_presets_sector(model_path): - from muse.readers.toml import read_presets_sector, read_settings +def test_read_presets_sector(default_model_path): + from muse.readers.toml import read_presets_sector - settings = read_settings(model_path / "settings.toml") + settings = read_settings(default_model_path / "settings.toml") data = read_presets_sector(settings, sector_name="residential_presets") # Check data against schema @@ -1029,11 +979,11 @@ def test_read_presets_sector(model_path): assert_single_coordinate(data, coord, expected) -def test_read_presets_sector__correlation(correlation_model_path): +def test_read_presets_sector__correlation(default_correlation_model_path): """Testing the read_presets_sector function with the correlation model.""" - from muse.readers.toml import read_presets_sector, read_settings + from muse.readers.toml import read_presets_sector - settings = read_settings(correlation_model_path / "settings.toml") + settings = read_settings(default_correlation_model_path / "settings.toml") data = read_presets_sector(settings, sector_name="residential_presets") # Check data against schema diff --git a/tests/test_read_csv.py b/tests/test_read_csv.py index a09e10c22..83b7d89e6 100644 --- a/tests/test_read_csv.py +++ b/tests/test_read_csv.py @@ -1,53 +1,11 @@ -import pytest - -from muse import examples -from muse.readers.csv import ( - read_agent_parameters_csv, - read_existing_trade_csv, - read_global_commodities_csv, - read_initial_capacity_csv, - read_macro_drivers_csv, - read_presets_csv, - read_projections_csv, - read_regression_parameters_csv, - read_technodata_timeslices_csv, - read_technodictionary_csv, - read_timeslice_shares_csv, - read_trade_technodata_csv, -) - - -@pytest.fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - return tmp_path / "model" - - -@pytest.fixture -def timeslice_model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default_timeslice", path=tmp_path) - return tmp_path / "model" - - -@pytest.fixture -def trade_model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="trade", 
path=tmp_path) - return tmp_path / "model" - - -@pytest.fixture -def correlation_model_path(tmp_path): - """Creates temporary folder containing the correlation model.""" - examples.copy_model(name="default_correlation", path=tmp_path) - return tmp_path / "model" - - -def test_read_technodictionary_csv(model_path): +from __future__ import annotations + + +def test_read_technodictionary_csv(default_model_path): """Test reading the technodictionary CSV file.""" - technodictionary_path = model_path / "power" / "Technodata.csv" + from muse.readers.csv.technologies import read_technodictionary_csv + + technodictionary_path = default_model_path / "power" / "Technodata.csv" technodictionary_df = read_technodictionary_csv(technodictionary_path) assert technodictionary_df is not None mandatory_columns = { @@ -74,9 +32,13 @@ def test_read_technodictionary_csv(model_path): assert set(technodictionary_df.columns) == mandatory_columns | extra_columns -def test_read_technodata_timeslices_csv(timeslice_model_path): +def test_read_technodata_timeslices_csv(default_timeslice_model_path): """Test reading the technodata timeslices CSV file.""" - timeslices_path = timeslice_model_path / "power" / "TechnodataTimeslices.csv" + from muse.readers.csv.technologies import read_technodata_timeslices_csv + + timeslices_path = ( + default_timeslice_model_path / "power" / "TechnodataTimeslices.csv" + ) timeslices_df = read_technodata_timeslices_csv(timeslices_path) mandatory_columns = { "utilization_factor", @@ -93,10 +55,12 @@ def test_read_technodata_timeslices_csv(timeslice_model_path): assert set(timeslices_df.columns) == mandatory_columns | extra_columns -def test_read_initial_capacity_csv(model_path): +def test_read_existing_capacity_csv(default_model_path): """Test reading the initial capacity CSV file.""" - capacity_path = model_path / "power" / "ExistingCapacity.csv" - capacity_df = read_initial_capacity_csv(capacity_path) + from muse.readers.csv.assets import read_existing_capacity_csv + + capacity_path = default_model_path / "power" / "ExistingCapacity.csv" + capacity_df = read_existing_capacity_csv(capacity_path) mandatory_columns = { "region", "technology", @@ -113,9 +77,11 @@ def test_read_initial_capacity_csv(model_path): assert set(capacity_df.columns) == mandatory_columns | extra_columns -def test_read_global_commodities_csv(model_path): +def test_read_global_commodities_csv(default_model_path): """Test reading the global commodities CSV file.""" - commodities_path = model_path / "GlobalCommodities.csv" + from muse.readers.csv.commodities import read_global_commodities_csv + + commodities_path = default_model_path / "GlobalCommodities.csv" commodities_df = read_global_commodities_csv(commodities_path) mandatory_columns = { "commodity", @@ -125,10 +91,14 @@ def test_read_global_commodities_csv(model_path): assert set(commodities_df.columns) == mandatory_columns | extra_columns -def test_read_timeslice_shares_csv(correlation_model_path): +def test_read_timeslice_shares_csv(default_correlation_model_path): """Test reading the timeslice shares CSV file.""" + from muse.readers.csv.correlation_consumption import read_timeslice_shares_csv + shares_path = ( - correlation_model_path / "residential_presets" / "TimesliceSharepreset.csv" + default_correlation_model_path + / "residential_presets" + / "TimesliceSharepreset.csv" ) shares_df = read_timeslice_shares_csv(shares_path) mandatory_columns = { @@ -145,10 +115,12 @@ def test_read_timeslice_shares_csv(correlation_model_path): assert set(shares_df.columns) == 
mandatory_columns | extra_columns -def test_read_agent_parameters_csv(model_path): +def test_read_agents_csv(default_model_path): """Test reading the agent parameters CSV file.""" - agents_path = model_path / "Agents.csv" - agents_df = read_agent_parameters_csv(agents_path) + from muse.readers.csv.agents import read_agents_csv + + agents_path = default_model_path / "Agents.csv" + agents_df = read_agents_csv(agents_path) mandatory_columns = { "search_rule", "quantity", @@ -166,9 +138,13 @@ def test_read_agent_parameters_csv(model_path): assert set(agents_df.columns) == mandatory_columns | extra_columns -def test_read_macro_drivers_csv(correlation_model_path): +def test_read_macro_drivers_csv(default_correlation_model_path): """Test reading the macro drivers CSV file.""" - macro_path = correlation_model_path / "residential_presets" / "Macrodrivers.csv" + from muse.readers.csv.correlation_consumption import read_macro_drivers_csv + + macro_path = ( + default_correlation_model_path / "residential_presets" / "Macrodrivers.csv" + ) macro_df = read_macro_drivers_csv(macro_path) mandatory_columns = { "region", @@ -183,9 +159,11 @@ def test_read_macro_drivers_csv(correlation_model_path): assert "GDP|PPP" in macro_df["variable"].values -def test_read_projections_csv(model_path): +def test_read_projections_csv(default_model_path): """Test reading the projections CSV file.""" - projections_path = model_path / "Projections.csv" + from muse.readers.csv.market import read_projections_csv + + projections_path = default_model_path / "Projections.csv" projections_df = read_projections_csv(projections_path) mandatory_columns = { "year", @@ -200,10 +178,14 @@ def test_read_projections_csv(model_path): assert set(projections_df.columns) == mandatory_columns | extra_columns -def test_read_regression_parameters_csv(correlation_model_path): +def test_read_regression_parameters_csv(default_correlation_model_path): """Test reading the regression parameters CSV file.""" + from muse.readers.csv.correlation_consumption import read_regression_parameters_csv + regression_path = ( - correlation_model_path / "residential_presets" / "regressionparameters.csv" + default_correlation_model_path + / "residential_presets" + / "regressionparameters.csv" ) regression_df = read_regression_parameters_csv(regression_path) mandatory_columns = { @@ -221,9 +203,13 @@ def test_read_regression_parameters_csv(correlation_model_path): assert set(regression_df.columns) == mandatory_columns | extra_columns -def test_read_presets_csv(model_path): +def test_read_presets_csv(default_model_path): """Test reading the presets CSV files.""" - presets_path = model_path / "residential_presets" / "Residential2020Consumption.csv" + from muse.readers.csv.presets import read_presets_csv + + presets_path = ( + default_model_path / "residential_presets" / "Residential2020Consumption.csv" + ) presets_df = read_presets_csv(presets_path) mandatory_columns = { @@ -238,6 +224,8 @@ def test_read_presets_csv(model_path): def test_read_existing_trade_csv(trade_model_path): """Test reading the existing trade CSV file.""" + from muse.readers.csv.trade_assets import read_existing_trade_csv + trade_path = trade_model_path / "power" / "ExistingTrade.csv" trade_df = read_existing_trade_csv(trade_path) mandatory_columns = { @@ -251,6 +239,8 @@ def test_read_existing_trade_csv(trade_model_path): def test_read_trade_technodata(trade_model_path): """Test reading the trade technodata CSV file.""" + from muse.readers.csv.trade_technodata import read_trade_technodata_csv + 
trade_path = trade_model_path / "power" / "TradeTechnodata.csv" trade_df = read_trade_technodata_csv(trade_path) mandatory_columns = {"technology", "region", "parameter"} diff --git a/tests/test_readers.py b/tests/test_readers.py index 924dcacff..cae57d94d 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -187,7 +187,9 @@ def test_suffix_path_formatting(suffix, tmp_path): def test_check_utilization_and_minimum_service(): """Test combined validation of utilization and minimum service factors.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + check_utilization_and_minimum_service_factors, + ) # Test valid case - create dataset with proper dimensions ds = xr.Dataset( @@ -240,7 +242,9 @@ def test_check_utilization_and_minimum_service(): def test_check_utilization_not_all_zero_fail(): """Test validation fails when all utilization factors are zero.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + check_utilization_and_minimum_service_factors, + ) ds = xr.Dataset( { @@ -260,7 +264,9 @@ def test_check_utilization_not_all_zero_fail(): ) def test_check_utilization_in_range_fail(values): """Test validation fails for utilization factors outside valid range.""" - from muse.readers.csv import check_utilization_and_minimum_service_factors + from muse.readers.csv.technologies import ( + check_utilization_and_minimum_service_factors, + ) ds = xr.Dataset( { @@ -277,7 +283,7 @@ def test_check_utilization_in_range_fail(values): def test_get_nan_coordinates(): """Test get_nan_coordinates for various scenarios.""" - from muse.readers.csv import get_nan_coordinates + from muse.readers.csv.helpers import get_nan_coordinates # Test 1: Explicit NaN values df1 = pd.DataFrame( diff --git a/tests/test_subsector.py b/tests/test_subsector.py index 29affb24b..93eae1fb3 100644 --- a/tests/test_subsector.py +++ b/tests/test_subsector.py @@ -8,7 +8,7 @@ from muse import examples from muse import investments as iv from muse.agents.factories import create_agent -from muse.readers import read_agent_parameters, read_initial_capacity +from muse.readers import read_agents, read_assets from muse.readers.toml import read_settings from muse.sectors.subsector import Subsector, aggregate_enduses @@ -41,8 +41,8 @@ def agent_params(model, tmp_path, technologies): """Common agent parameters setup.""" examples.copy_model(model, tmp_path) path = tmp_path / "model" / "Agents.csv" - params = read_agent_parameters(path) - capa = read_initial_capacity(path.with_name("residential") / "ExistingCapacity.csv") + params = read_agents(path) + capa = read_assets(path.with_name("residential") / "ExistingCapacity.csv") for param in params: param.update( diff --git a/tests/test_wizard.py b/tests/test_wizard.py index da3e29ca3..1dcdfd07a 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -1,10 +1,8 @@ from pathlib import Path import pandas as pd -import pytest from tomlkit import dumps, parse -from muse import examples from muse.wizard import ( add_agent, add_new_commodity, @@ -17,20 +15,6 @@ ) -@pytest.fixture -def model_path(tmp_path): - """Creates temporary folder containing the default model.""" - examples.copy_model(name="default", path=tmp_path) - return tmp_path / "model" - - -@pytest.fixture -def model_path_retro(tmp_path): - """Creates temporary folder containing the default_retro model.""" - examples.copy_model(name="default_retro", path=tmp_path) - return 
tmp_path / "model" - - def assert_values_in_csv(file_path: Path, column: str, expected_values: list): """Helper function to check if values exist in a CSV column.""" df = pd.read_csv(file_path) @@ -75,19 +59,19 @@ def test_get_sectors(tmp_path): assert set(get_sectors(model_path)) == {"sector1", "sector2"} -def test_add_new_commodity(model_path): +def test_add_new_commodity(default_model_path): """Test the add_new_commodity function on the default model.""" - add_new_commodity(model_path, "new_commodity", "power", "wind") + add_new_commodity(default_model_path, "new_commodity", "power", "wind") # Check global commodities assert_values_in_csv( - model_path / "GlobalCommodities.csv", "commodity", ["new_commodity"] + default_model_path / "GlobalCommodities.csv", "commodity", ["new_commodity"] ) -def test_add_new_process(model_path): +def test_add_new_process(default_model_path): """Test the add_new_process function on the default model.""" - add_new_process(model_path, "new_process", "power", "windturbine") + add_new_process(default_model_path, "new_process", "power", "windturbine") files_to_check = [ "CommIn.csv", @@ -96,36 +80,44 @@ def test_add_new_process(model_path): "Technodata.csv", ] for file in files_to_check: - assert_values_in_csv(model_path / "power" / file, "technology", ["new_process"]) + assert_values_in_csv( + default_model_path / "power" / file, "technology", ["new_process"] + ) -def test_technodata_for_new_year(model_path): +def test_technodata_for_new_year(default_model_path): """Test the add_price_data_for_new_year function on the default model.""" - add_technodata_for_new_year(model_path, 2030, "power", 2020) + add_technodata_for_new_year(default_model_path, 2030, "power", 2020) - assert_values_in_csv(model_path / "power" / "Technodata.csv", "year", [2030]) + assert_values_in_csv( + default_model_path / "power" / "Technodata.csv", "year", [2030] + ) -def test_add_agent(model_path_retro): +def test_add_agent(default_retro_model_path): """Test the add_agent function on the default_retro model.""" - add_agent(model_path_retro, "A2", "A1", "Agent3", "Agent4") + add_agent(default_retro_model_path, "A2", "A1", "Agent3", "Agent4") # Check Agents.csv - assert_values_in_csv(model_path_retro / "Agents.csv", "name", ["A2"]) + assert_values_in_csv(default_retro_model_path / "Agents.csv", "name", ["A2"]) for share in ["Agent3", "Agent4"]: - assert_values_in_csv(model_path_retro / "Agents.csv", "agent_share", [share]) + assert_values_in_csv( + default_retro_model_path / "Agents.csv", "agent_share", [share] + ) # Check Technodata.csv files for sector in ["power", "gas"]: - assert_columns_exist(model_path_retro / sector / "Technodata.csv", ["Agent4"]) + assert_columns_exist( + default_retro_model_path / sector / "Technodata.csv", ["Agent4"] + ) -def test_add_region(model_path): +def test_add_region(default_model_path): """Test the add_region function on the default model.""" - add_region(model_path, "R2", "R1") + add_region(default_model_path, "R2", "R1") # Check settings.toml - with open(model_path / "settings.toml") as f: + with open(default_model_path / "settings.toml") as f: settings = parse(f.read()) assert "R2" in settings["regions"] @@ -136,17 +128,17 @@ def test_add_region(model_path): "CommOut.csv", "ExistingCapacity.csv", ] - for sector in get_sectors(model_path): + for sector in get_sectors(default_model_path): for file in files_to_check: - assert_values_in_csv(model_path / sector / file, "region", ["R2"]) + assert_values_in_csv(default_model_path / sector / file, "region", 
["R2"]) -def test_add_timeslice(model_path): +def test_add_timeslice(default_model_path): """Test the add_timeslice function on the default model.""" - add_timeslice(model_path, "midnight", "evening") + add_timeslice(default_model_path, "midnight", "evening") # Check settings.toml - with open(model_path / "settings.toml") as f: + with open(default_model_path / "settings.toml") as f: settings = parse(f.read()) timeslices = settings["timeslices"]["all-year"]["all-week"] assert "midnight" in timeslices @@ -154,5 +146,5 @@ def test_add_timeslice(model_path): # Check preset files for preset in ["Residential2020Consumption.csv", "Residential2050Consumption.csv"]: - df = pd.read_csv(model_path / "residential_presets" / preset) + df = pd.read_csv(default_model_path / "residential_presets" / preset) assert len(df["timeslice"].unique()) == n_timeslices