diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py
new file mode 100644
index 000000000..c8833e902
--- /dev/null
+++ b/src/muse/new_input/readers.py
@@ -0,0 +1,388 @@
+import duckdb
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from muse.timeslices import QuantityType
+
+
+def read_inputs(data_dir):
+    """Read all input tables from data_dir and assemble the model datasets."""
+    data = {}
+    con = duckdb.connect(":memory:")
+
+    with open(data_dir / "timeslices.csv") as f:
+        timeslices = read_timeslices_csv(f, con)
+
+    with open(data_dir / "commodities.csv") as f:
+        commodities = read_commodities_csv(f, con)
+
+    with open(data_dir / "regions.csv") as f:
+        regions = read_regions_csv(f, con)
+
+    with open(data_dir / "commodity_trade.csv") as f:
+        commodity_trade = read_commodity_trade_csv(f, con)
+
+    with open(data_dir / "commodity_costs.csv") as f:
+        commodity_costs = read_commodity_costs_csv(f, con)
+
+    with open(data_dir / "demand.csv") as f:
+        demand = read_demand_csv(f, con)
+
+    with open(data_dir / "demand_slicing.csv") as f:
+        demand_slicing = read_demand_slicing_csv(f, con)
+
+    data["global_commodities"] = calculate_global_commodities(commodities)
+    data["demand"] = calculate_demand(
+        commodities, regions, timeslices, demand, demand_slicing
+    )
+    data["initial_market"] = calculate_initial_market(
+        commodities, regions, timeslices, commodity_trade, commodity_costs
+    )
+    return data
+
+
+def read_timeslices_csv(buffer_, con):
+    sql = """CREATE TABLE timeslices (
+        id BIGINT PRIMARY KEY,
+        month VARCHAR,
+        day VARCHAR,
+        hour VARCHAR,
+        fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1)
+    );
+    """
+    con.sql(sql)
+    # `rel` is picked up by name inside the SQL below via DuckDB's replacement scan
+    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
+    con.sql("INSERT INTO timeslices SELECT id, month, day, hour, fraction FROM rel;")
+    return con.sql("SELECT * FROM timeslices").fetchnumpy()
+
+
+def read_commodities_csv(buffer_, con):
+    sql = """CREATE TABLE commodities (
+        id VARCHAR PRIMARY KEY,
+        type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')),
+        unit VARCHAR
+    );
+    """
+    con.sql(sql)
+    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
+    con.sql("INSERT INTO commodities SELECT id, type, unit FROM rel;")
+    return con.sql("SELECT * FROM commodities").fetchnumpy()
+
+
+def read_regions_csv(buffer_, con):
+    sql = """CREATE TABLE regions (
+        id VARCHAR PRIMARY KEY
+    );
+    """
+    con.sql(sql)
+    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
+    con.sql("INSERT INTO regions SELECT id FROM rel;")
+    return con.sql("SELECT * FROM regions").fetchnumpy()
+
+
+def read_commodity_trade_csv(buffer_, con):
+    sql = """CREATE TABLE commodity_trade (
+        commodity VARCHAR REFERENCES commodities(id),
+        region VARCHAR REFERENCES regions(id),
+        year BIGINT,
+        import DOUBLE,
+        export DOUBLE,
+        PRIMARY KEY (commodity, region, year)
+    );
+    """
+    con.sql(sql)
+    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
+    con.sql("""INSERT INTO commodity_trade SELECT
+        commodity_id, region_id, year, import, export FROM rel;""")
+    return con.sql("SELECT * FROM commodity_trade").fetchnumpy()
+
+
+def read_commodity_costs_csv(buffer_, con):
+    sql = """CREATE TABLE commodity_costs (
+        commodity VARCHAR REFERENCES commodities(id),
+        region VARCHAR REFERENCES regions(id),
+        year BIGINT,
+        value DOUBLE,
+        PRIMARY KEY (commodity, region, year)
+    );
+    """
+    con.sql(sql)
+    rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
+    con.sql("""INSERT INTO commodity_costs SELECT
+        commodity_id, region_id, year, value FROM rel;""")
+    return con.sql("SELECT * FROM commodity_costs").fetchnumpy()
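+
+
+# Illustrative usage of the reader functions (a sketch, not part of the module
+# API; the inline CSV contents are hypothetical). The foreign-key constraints
+# mean tables must be loaded in dependency order: commodities and regions
+# before the tables that reference them.
+#
+#     >>> import duckdb
+#     >>> from io import StringIO
+#     >>> con = duckdb.connect(":memory:")
+#     >>> _ = read_commodities_csv(StringIO("id,type,unit\nheat,energy,PJ\n"), con)
+#     >>> _ = read_regions_csv(StringIO("id\nR1\n"), con)
+#     >>> costs = read_commodity_costs_csv(
+#     ...     StringIO("commodity_id,region_id,year,value\nheat,R1,2020,1.5\n"), con
+#     ... )
+#     >>> costs["value"]
+#     array([1.5])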
return con.sql("SELECT * from commodity_costs").fetchnumpy() + + +def read_demand_csv(buffer_, con): + sql = """CREATE TABLE demand ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), + year BIGINT, + demand DOUBLE, + PRIMARY KEY (commodity, region, year) + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO demand SELECT commodity_id, region_id, year, demand FROM rel;") + return con.sql("SELECT * from demand").fetchnumpy() + + +def read_demand_slicing_csv(buffer_, con): + sql = """CREATE TABLE demand_slicing ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), + year BIGINT, + timeslice BIGINT REFERENCES timeslices(id), + fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + PRIMARY KEY (commodity, region, year, timeslice), + FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year) + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO demand_slicing SELECT + commodity_id, region_id, year, timeslice_id, fraction FROM rel;""") + return con.sql("SELECT * from demand_slicing").fetchnumpy() + + +def calculate_global_commodities(commodities): + names = commodities["id"].astype(np.dtype("str")) + types = commodities["type"].astype(np.dtype("str")) + units = commodities["unit"].astype(np.dtype("str")) + + type_array = xr.DataArray( + data=types, dims=["commodity"], coords=dict(commodity=names) + ) + + unit_array = xr.DataArray( + data=units, dims=["commodity"], coords=dict(commodity=names) + ) + + data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) + return data + + +def calculate_demand( + commodities, regions, timeslices, demand, demand_slicing +) -> xr.DataArray: + """Calculate demand data for all commodities, regions, years, and timeslices. + + Result: A DataArray with a demand value for every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the demand table + - timeslice: all timeslices specified in the timeslices table + + Checks: + - If demand data is specified for one year, it must be specified for all years. + - If demand is nonzero, slicing data must be present. + - If slicing data is specified for a commodity/region/year, the sum of + the fractions must be 1, and all timeslices must be present. + + Fills: + - If demand data is not specified for a commodity/region combination, the demand is + 0 for all years and timeslices. + + Todo: + - Interpolation to allow for missing years in demand data. + - Ability to leave the year field blank in both tables to indicate all years + - Allow slicing data to be missing -> demand is spread equally across timeslices + - Allow more flexibility for timeslices (e.g. 
can specify "winter" to apply to all + winter timeslices, or "all" to apply to all timeslices) + """ + # Prepare dataframes + df_demand = pd.DataFrame(demand).set_index(["commodity", "region", "year"]) + df_slicing = pd.DataFrame(demand_slicing).set_index( + ["commodity", "region", "year", "timeslice"] + ) + + # DataArray dimensions + all_commodities = commodities["id"].astype(np.dtype("str")) + all_regions = regions["id"].astype(np.dtype("str")) + all_years = df_demand.index.get_level_values("year").unique() + all_timeslices = timeslices["id"].astype(np.dtype("int")) + + # CHECK: all years are specified for each commodity/region combination + check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years) + + # CHECK: if slicing data is present, all timeslices must be specified + check_all_values_specified( + df_slicing, ["commodity", "region", "year"], "timeslice", all_timeslices + ) + + # CHECK: timeslice fractions sum to 1 + check_timeslice_sum = df_slicing.groupby(["commodity", "region", "year"]).apply( + lambda x: np.isclose(x["fraction"].sum(), 1) + ) + if not check_timeslice_sum.all(): + raise DataValidationError + + # CHECK: if demand data >0, fraction data must be specified + check_fraction_data_present = ( + df_demand[df_demand["demand"] > 0] + .index.isin(df_slicing.droplevel("timeslice").index) + .all() + ) + if not check_fraction_data_present.all(): + raise DataValidationError + + # FILL: demand is zero if unspecified + df_demand = df_demand.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years], + names=["commodity", "region", "year"], + ), + fill_value=0, + ) + + # FILL: slice data is zero if unspecified + df_slicing = df_slicing.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years, all_timeslices], + names=["commodity", "region", "year", "timeslice"], + ), + fill_value=0, + ) + + # Create DataArray + da_demand = df_demand.to_xarray()["demand"] + da_slicing = df_slicing.to_xarray()["fraction"] + data = da_demand * da_slicing + return data + + +def calculate_initial_market( + commodities, regions, timeslices, commodity_trade, commodity_costs +) -> xr.Dataset: + """Calculate trade and price data for all commodities, regions and years. + + Result: A Dataset with variables: + - prices + - exports + - imports + - static_trade + For every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the commodity_costs table + - timeslice (multiindex): all timeslices specified in the timeslices table + + Checks: + - If trade data is specified for one year, it must be specified for all years. + - If price data is specified for one year, it must be specified for all years. 
+
+
+def calculate_initial_market(
+    commodities, regions, timeslices, commodity_trade, commodity_costs
+) -> xr.Dataset:
+    """Calculate trade and price data for all commodities, regions and years.
+
+    Result: A Dataset with variables:
+    - prices
+    - export
+    - import
+    - static_trade
+    For every combination of:
+    - commodity: all commodities specified in the commodities table
+    - region: all regions specified in the regions table
+    - year: all years specified in the commodity_costs table
+    - timeslice (multiindex): all timeslices specified in the timeslices table
+
+    Checks:
+    - If trade data is specified for one year, it must be specified for all years.
+    - If price data is specified for one year, it must be specified for all years.
+
+    Fills:
+    - If trade data is not specified for a commodity/region combination, imports
+      and exports are both zero.
+    - If price data is not specified for a commodity/region combination, the price
+      is zero.
+
+    Todo:
+    - Allow data to be specified on a timeslice level (optional).
+    - Interpolation, missing year field, flexible timeslice specification as above.
+    """
+    # Prepare dataframes
+    df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"])
+    df_costs = (
+        pd.DataFrame(commodity_costs)
+        .set_index(["commodity", "region", "year"])
+        .rename(columns={"value": "prices"})
+    )
+    df_timeslices = pd.DataFrame(timeslices).set_index(["month", "day", "hour"])
+
+    # Dataset dimensions
+    all_commodities = commodities["id"].astype(np.dtype("str"))
+    all_regions = regions["id"].astype(np.dtype("str"))
+    all_years = df_costs.index.get_level_values("year").unique()
+
+    # CHECK: all years are specified for each commodity/region combination
+    check_all_values_specified(df_trade, ["commodity", "region"], "year", all_years)
+    check_all_values_specified(df_costs, ["commodity", "region"], "year", all_years)
+
+    # FILL: price is zero if unspecified
+    df_costs = df_costs.reindex(
+        pd.MultiIndex.from_product(
+            [all_commodities, all_regions, all_years],
+            names=["commodity", "region", "year"],
+        ),
+        fill_value=0,
+    )
+
+    # FILL: trade is zero if unspecified
+    df_trade = df_trade.reindex(
+        pd.MultiIndex.from_product(
+            [all_commodities, all_regions, all_years],
+            names=["commodity", "region", "year"],
+        ),
+        fill_value=0,
+    )
+
+    # Calculate static trade
+    df_trade["static_trade"] = df_trade["export"] - df_trade["import"]
+
+    # Create xarray datasets
+    xr_costs = df_costs.to_xarray()
+    xr_trade = df_trade.to_xarray()
+
+    # Project over timeslices: prices are intensive (broadcast unchanged), whereas
+    # trade volumes are extensive (distributed in proportion to timeslice length)
+    ts = df_timeslices.to_xarray()["fraction"].stack(timeslice=("month", "day", "hour"))
+    xr_costs = project_timeslice(xr_costs, ts, QuantityType.INTENSIVE)
+    xr_trade = project_timeslice(xr_trade, ts, QuantityType.EXTENSIVE)
+
+    # Combine data
+    data = xr.merge([xr_costs, xr_trade])
+    return data
+
+
+class DataValidationError(ValueError):
+    pass
+
+
+def check_all_values_specified(
+    df: pd.DataFrame, group_by_cols: list[str], column_name: str, values: list
+) -> None:
+    """Check that the required values are specified in a dataframe.
+
+    Checks that a row exists for all specified values of column_name for each
+    group in the grouped dataframe.
+    """
+    if not (
+        df.groupby(group_by_cols)
+        .apply(
+            lambda x: (
+                set(x.index.get_level_values(column_name).unique()) == set(values)
+            )
+        )
+        .all()
+    ):
+        msg = (
+            f"Each {'/'.join(group_by_cols)} group must have a row "
+            f"for every {column_name}."
+        )
+        raise DataValidationError(msg)
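+
+
+# A minimal sketch of the validation helper on a hypothetical frame: the
+# ("heat", "R1") group below is missing year 2050, so the check raises.
+#
+#     >>> import pandas as pd
+#     >>> df = pd.DataFrame(
+#     ...     {"commodity": ["heat"], "region": ["R1"], "year": [2020], "demand": [10.0]}
+#     ... ).set_index(["commodity", "region", "year"])
+#     >>> check_all_values_specified(df, ["commodity", "region"], "year", [2020, 2050])
+#     Traceback (most recent call last):
++#         ...
+#     DataValidationError: Each commodity/region group must have a row for every year.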
+
+
+def project_timeslice(
+    data: xr.Dataset, timeslices: xr.DataArray, quantity_type: QuantityType
+) -> xr.Dataset:
+    """Project a dataset over a new timeslice dimension.
+
+    The projection can be done in one of two ways, depending on whether the
+    quantity type is extensive or intensive (see `QuantityType`): extensive
+    quantities are distributed in proportion to timeslice length, whereas
+    intensive quantities are broadcast unchanged.
+
+    Args:
+        data: Dataset to project
+        timeslices: DataArray of timeslice levels, with values between 0 and 1
+            representing the timeslice length (fraction of the year)
+        quantity_type: Type of projection to perform (QuantityType.EXTENSIVE or
+            QuantityType.INTENSIVE)
+
+    Returns:
+        Projected dataset
+    """
+    assert "timeslice" in timeslices.dims
+    assert "timeslice" not in data.dims
+
+    if quantity_type is QuantityType.EXTENSIVE:
+        return data * timeslices
+    if quantity_type is QuantityType.INTENSIVE:
+        return data * xr.ones_like(timeslices)
+    raise ValueError(f"Unknown quantity type: {quantity_type}")
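+
+
+# A minimal sketch of the two projection modes (hypothetical values; the
+# fractions sum to 1 as required):
+#
+#     >>> import xarray as xr
+#     >>> ts = xr.DataArray([0.25, 0.75], dims="timeslice")
+#     >>> data = xr.Dataset({"prices": xr.DataArray(10.0), "import": xr.DataArray(8.0)})
+#     >>> project_timeslice(data, ts, QuantityType.INTENSIVE)["prices"].values
+#     array([10., 10.])
+#     >>> project_timeslice(data, ts, QuantityType.EXTENSIVE)["import"].values
+#     array([2., 6.])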
data["import"].size == 0 + assert data["export"].size == 0 + + +def test_read_commodity_costs_csv(populate_commodity_costs): + data = populate_commodity_costs + # Only checking the first element of each array, as the table is large + assert next(iter(data["commodity"])) == "electricity" + assert next(iter(data["region"])) == "R1" + assert next(iter(data["year"])) == 2010 + assert next(iter(data["value"])) == approx(14.81481) + + +def test_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_read_demand_slicing_csv(populate_demand_slicing): + data = populate_demand_slicing + assert np.all(data["commodity"] == "heat") + assert np.all(data["region"] == "R1") + # assert np.all(data["timeslice"] == np.array([0, 1])) + assert np.all( + data["fraction"] + == np.array([0.1, 0.15, 0.1, 0.15, 0.3, 0.2, 0.1, 0.15, 0.1, 0.15, 0.3, 0.2]) + ) + + +def test_read_commodities_csv_type_constraint(con): + from muse.new_input.readers import read_commodities_csv + + csv = StringIO("id,type,unit\nfoo,invalid,bar\n") + with raises(duckdb.ConstraintException): + read_commodities_csv(csv, con) + + +def test_read_demand_csv_commodity_constraint( + con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def test_read_demand_csv_region_constraint(con, populate_commodities, populate_regions): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, +): + from muse.new_input.readers import calculate_demand + + data = calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, + ) + + assert isinstance(data, xr.DataArray) + assert data.dtype == np.float64 + + assert set(data.dims) == {"year", "commodity", "region", "timeslice"} + assert set(data.coords["region"].values) == {"R1"} + assert set(data.coords["timeslice"].values) == set(range(1, 7)) + assert set(data.coords["year"].values) == {2020, 2050} + assert set(data.coords["commodity"].values) == { + "electricity", + "gas", + "heat", + "wind", + "CO2f", + } + + assert data.sel(year=2020, commodity="heat", region="R1", timeslice=1) == 1 + + +def test_calculate_initial_market( + populate_commodities, + populate_regions, + 
+
+
+def test_calculate_initial_market(
+    populate_commodities,
+    populate_regions,
+    populate_timeslices,
+    populate_commodity_trade,
+    populate_commodity_costs,
+):
+    from muse.new_input.readers import calculate_initial_market
+
+    data = calculate_initial_market(
+        populate_commodities,
+        populate_regions,
+        populate_timeslices,
+        populate_commodity_trade,
+        populate_commodity_costs,
+    )
+
+    assert isinstance(data, xr.Dataset)
+    assert set(data.dims) == {"region", "year", "commodity", "timeslice"}
+    for dt in data.dtypes.values():
+        assert dt == np.dtype("float64")
+    assert set(data.coords["region"].values) == {"R1"}
+    assert set(data.coords["year"].values) == set(range(2010, 2105, 5))
+    assert set(data.coords["commodity"].values) == {
+        "electricity",
+        "gas",
+        "heat",
+        "CO2f",
+        "wind",
+    }
+
+    month_values = ["all-year"] * 6
+    day_values = ["all-week"] * 6
+    hour_values = [
+        "night",
+        "morning",
+        "afternoon",
+        "early-peak",
+        "late-peak",
+        "evening",
+    ]
+    assert set(data.coords["timeslice"].values) == set(
+        zip(month_values, day_values, hour_values)
+    )
+    assert set(data.coords["month"].values) == set(month_values)
+    assert set(data.coords["day"].values) == set(day_values)
+    assert set(data.coords["hour"].values) == set(hour_values)
+
+    assert all(var.coords.equals(data.coords) for var in data.data_vars.values())
+
+    prices = data.data_vars["prices"]
+    assert prices.sel(
+        year=2010,
+        region="R1",
+        commodity="electricity",
+        timeslice=("all-year", "all-week", "night"),
+    ).item() == approx(14.81481, abs=1e-4)
+
+    exports = data.data_vars["export"]
+    assert exports.sel(
+        year=2010,
+        region="R1",
+        commodity="electricity",
+        timeslice=("all-year", "all-week", "night"),
+    ).item() == 0
+
+    imports = data.data_vars["import"]
+    assert imports.sel(
+        year=2010,
+        region="R1",
+        commodity="electricity",
+        timeslice=("all-year", "all-week", "night"),
+    ).item() == 0
+
+    static_trade = data.data_vars["static_trade"]
+    assert static_trade.sel(
+        year=2010,
+        region="R1",
+        commodity="electricity",
+        timeslice=("all-year", "all-week", "night"),
+    ).item() == 0
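+
+
+# An illustrative end-to-end check (a sketch: the expected keys mirror those
+# assembled by muse.new_input.readers.read_inputs).
+def test_read_inputs(default_new_input):
+    from muse.new_input.readers import read_inputs
+
+    data = read_inputs(default_new_input)
+    assert set(data) == {"global_commodities", "demand", "initial_market"}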