Skip to content

"warmup_sample_stats" not saved in ClickHouse backend #93

@thelogicalgrammar

Description

@thelogicalgrammar

I am running the following while running a local ClickHouse server (PyMC v.5.3.0):

import arviz
import numpy as np
import pymc as pm
import mcbackend as mcb
from clickhouse_driver import Client


def define_simple_model():
    """Build the small PyMC model used to reproduce the backend errors.

    Synthetic observations are drawn around three straight lines (one per
    condition) evaluated on an evenly spaced grid between 0 and 5.  The model
    contains a scalar variable, a vector variable, a deterministic matrix
    combining both, and a normal likelihood on the observations.

    Returns the constructed ``pm.Model``.
    """
    seconds = np.linspace(0, 5)

    # One random slope per condition, broadcast against the time grid; the
    # observations are normal draws centred on the resulting lines.
    slopes = np.random.uniform(size=3)
    line_means = 0.5 + slopes[:, None] * seconds[None, :]
    observations = np.random.normal(line_means)

    model_coords = {"condition": ["A", "B", "C"]}
    with pm.Model(coords=model_coords) as pmodel:
        time_data = pm.ConstantData("seconds", seconds, dims="time")
        intercept = pm.Normal("scalar")
        slope = pm.Uniform("vector", dims="condition")
        matrix = pm.Deterministic(
            "matrix",
            intercept + slope[:, None] * time_data[None, :],
            dims=("condition", "time"),
        )
        observed = pm.MutableData("obs", observations, dims=("condition", "time"))
        pm.Normal("L", matrix, observed=observed, dims=("condition", "time"))

    return pmodel


if __name__ == "__main__":
    # Sample the example model, recording the draws with mcbackend's
    # NumPy backend.
    simple_model = define_simple_model()
    backend = mcb.NumPyBackend()

    with simple_model:
        trace = pm.sample(trace=backend)

    print(trace)

This raises the following error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [scalar, vector]
████████████████████████████████████████████████████████████| 100.00% [8000/8000 00:01<00:00 Sampling 4 chains, 0 divergences]
Traceback (most recent call last):
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 35, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 702, in sample
    return _sample_return(
           ^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 736, in _sample_return
    mtrace = MultiTrace(traces)[:length]
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 370, in __getitem__
    return self._slice(idx)
           ^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in _slice
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/base.py", line 537, in <listcomp>
    new_traces = [trace._slice(slice) for trace in self._straces.values()]
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 194, in _slice
    stats = self._chain.get_stats_at(i, stat_names=snames)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in get_stats_at
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/numpy.py", line 106, in <dictcomp>
    return {sn: numpy.asarray(self._stats[sn][idx]) for sn in stat_names}
                              ~~~~~~~~~~~~~~~^^^^^
IndexError: index 1000 is out of bounds for axis 0 with size 1000

Ultimately, I am interested in running with the ClickHouse backend. When I replace the relevant part:

if __name__ == "__main__":
    simple_model = define_simple_model()

    # Connect to the local ClickHouse server and list its databases to
    # check that the client is defined properly.
    ch_client = Client("localhost")
    print(ch_client.execute("SHOW DATABASES"))
    backend = mcb.ClickHouseBackend(ch_client)

    # Sample the example model, recording the draws in ClickHouse.
    with simple_model:
        trace = pm.sample(trace=backend)

    print(trace)

another error is raised:

[('INFORMATION_SCHEMA',), ('default',), ('information_schema',), ('system',)]
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Dimensionality of sample stat 'sampler_0__warning' is undefined. Assuming ndim=0.
Traceback (most recent call last):
  File "/mnt/c/Users/faust/Dropbox/Tubingen/joint_learning/model/test_clickhouse_backend.py", line 37, in <module>
    trace = pm.sample(
            ^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 623, in sample
    run, traces = init_traces(
                  ^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/__init__.py", line 127, in init_traces
    return init_chain_adapters(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 278, in init_chain_adapters
    adapters = [
               ^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/backends/mcbackend.py", line 280, in <listcomp>
    chain=run.init_chain(chain_number=chain_number),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 325, in init_chain
    create_chain_table(self._client, cmeta, self.meta)
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 98, in create_chain_table
    columns.append(column_spec_for(var, is_stat=True))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/mcbackend/backends/clickhouse.py", line 65, in column_spec_for
    raise KeyError(
KeyError: "Don't know how to store dtype object of 'sampler_0__warning' (is_stat=True) in ClickHouse."

I "fixed" this error by adding the following lines in "mcbackend/backends/clickhouse.py" on line 101:

# Workaround from this report: relabel object-dtype variables as 'str' so the
# ClickHouse backend no longer rejects them with the KeyError shown above.
# NOTE(review): this overwrites the variable's dtype metadata in place.
if var.dtype=='object':
    var.dtype='str'

As it turns out, this only affects the stat "sampler_0__warning":

Variable(name='tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__depth', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tune', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__mean_tree_accept', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__step_size_bar', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__tree_size', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__diverging', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__energy', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__max_energy_error', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__model_logp', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__process_time_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_diff', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__perf_counter_start', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__largest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__smallest_eigval', dtype='float64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__index_in_trajectory', dtype='int64', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__reached_max_treedepth', dtype='bool', shape=[], dims=[], is_deterministic=False, undefined_ndim=False)
True 


Variable(name='sampler_0__warning', dtype='object', shape=[], dims=[], is_deterministic=False, undefined_ndim=True)
False 

I ran the code above a few times with the ClickHouse backend (with the fix to mcbackend). Then I ran this code to check that the traces were being saved in the database:

# Connect to the local ClickHouse server and wrap the client in a backend.
# NOTE(review): the earlier snippets import the package as `mcb`; this snippet
# assumes a plain `import mcbackend` (its imports are not shown) - confirm.
ch_client = Client("localhost")
backend = mcbackend.ClickHouseBackend(ch_client)

# Show the runs known to the backend.
print("list of runs:\n", backend.get_runs())

# Fetch a single run from the database (downloads just metadata)
run = backend.get_run("6F6PW")

# Convert everything to `InferenceData`
idata = run.to_inferencedata()

print(idata)

As expected, this prints out the following:

list of runs:
                             created_at   
rid                                      
99PHY 2023-05-19 02:54:46.005139+00:00  \
TREHW 2023-05-19 02:56:15.987728+00:00   
1PQK6 2023-05-19 02:56:56.700751+00:00   
YKT3N 2023-05-19 03:34:37.115745+00:00   
WLKW4 2023-05-19 03:43:04.277188+00:00   
9HHAC 2023-05-19 03:49:53.529272+00:00   
1P93X 2023-05-19 03:51:50.767925+00:00   
Q8LLA 2023-05-19 03:54:46.947725+00:00   
DPP94 2023-05-19 03:56:24.189411+00:00   
9CFX4 2023-05-19 03:58:19.232365+00:00   
C3AZB 2023-05-19 04:01:53.742057+00:00   
A36Y7 2023-05-19 04:05:15.148334+00:00   
M4LND 2023-05-19 04:06:23.221856+00:00   
6F6PW 2023-05-19 04:44:46.898726+00:00   

                                                   proto  
rid                                                       
99PHY  RunMeta(rid='99PHY', variables=[Variable(name=...  
TREHW  RunMeta(rid='TREHW', variables=[Variable(name=...  
1PQK6  RunMeta(rid='1PQK6', variables=[Variable(name=...  
YKT3N  RunMeta(rid='YKT3N', variables=[Variable(name=...  
WLKW4  RunMeta(rid='WLKW4', variables=[Variable(name=...  
9HHAC  RunMeta(rid='9HHAC', variables=[Variable(name=...  
1P93X  RunMeta(rid='1P93X', variables=[Variable(name=...  
Q8LLA  RunMeta(rid='Q8LLA', variables=[Variable(name=...  
DPP94  RunMeta(rid='DPP94', variables=[Variable(name=...  
9CFX4  RunMeta(rid='9CFX4', variables=[Variable(name=...  
C3AZB  RunMeta(rid='C3AZB', variables=[Variable(name=...  
A36Y7  RunMeta(rid='A36Y7', variables=[Variable(name=...  
M4LND  RunMeta(rid='M4LND', variables=[Variable(name=...  
6F6PW  RunMeta(rid='6F6PW', variables=[Variable(name=...  
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
	> constant_data

Warmup iterations saved (warmup_*).
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(
/home/fausto/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/arviz/data/base.py:221: UserWarning: More chains (3) than draws (0). Passed array should have shape (chains, draws, *shape)
  warnings.warn(

However, this still has a problem, namely that all the fields saved in the "warmup_sample_stats" group of the idata are empty.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions