|
17 | 17 | from pytensor.graph.replace import clone_replace |
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery |
19 | 19 | from pytensor.tensor.random.basic import ( |
| 20 | + _gamma, |
20 | 21 | bernoulli, |
21 | 22 | beta, |
22 | 23 | betabinom, |
@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size): |
351 | 352 | ], |
352 | 353 | ) |
353 | 354 | def test_gamma_samples(a, b, size): |
354 | | - gamma_test_fn = fixed_scipy_rvs("gamma") |
355 | | - |
356 | | - def test_fn(shape, rate, **kwargs): |
357 | | - return gamma_test_fn(shape, scale=1.0 / rate, **kwargs) |
358 | | - |
359 | 355 | compare_sample_values( |
360 | | - gamma, |
| 356 | + _gamma, |
361 | 357 | a, |
362 | 358 | b, |
363 | 359 | size=size, |
364 | | - test_fn=test_fn, |
365 | 360 | ) |
366 | 361 |
|
367 | 362 |
|
| 363 | +def test_gamma_deprecation_wrapper_fn(): |
| 364 | + out = gamma(5.0, scale=0.5, size=(5,)) |
| 365 | + assert out.type.shape == (5,) |
| 366 | + assert out.owner.inputs[-1].eval() == 0.5 |
| 367 | + |
| 368 | + with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"): |
| 369 | + out = gamma([5.0, 10.0], 2.0, size=None) |
| 370 | + assert out.type.shape == (2,) |
| 371 | + assert out.owner.inputs[-1].eval() == 0.5 |
| 372 | + |
| 373 | + with pytest.raises(ValueError, match="Must specify scale"): |
| 374 | + gamma(5.0) |
| 375 | + |
| 376 | + with pytest.raises(ValueError, match="Cannot specify both rate and scale"): |
| 377 | + gamma(5.0, rate=2.0, scale=0.5) |
| 378 | + |
| 379 | + |
368 | 380 | @pytest.mark.parametrize( |
369 | 381 | "df, size", |
370 | 382 | [ |
@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size): |
470 | 482 |
|
471 | 483 |
|
472 | 484 | @pytest.mark.parametrize( |
473 | | - "alpha, size", |
| 485 | + "alpha, scale, size", |
474 | 486 | [ |
475 | | - (np.array(0.5, dtype=config.floatX), None), |
476 | | - (np.array(0.5, dtype=config.floatX), []), |
| 487 | + (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None), |
| 488 | + (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []), |
477 | 489 | ( |
478 | 490 | np.full((1, 2), 0.5, dtype=config.floatX), |
| 491 | + np.array([0.5, 1.0], dtype=config.floatX), |
479 | 492 | None, |
480 | 493 | ), |
481 | 494 | ], |
482 | 495 | ) |
483 | | -def test_pareto_samples(alpha, size): |
484 | | - compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto")) |
| 496 | +def test_pareto_samples(alpha, scale, size): |
| 497 | + pareto_test_fn = fixed_scipy_rvs("pareto") |
| 498 | + |
| 499 | + def test_fn(shape, scale, **kwargs): |
| 500 | + return pareto_test_fn(shape, scale=scale, **kwargs) |
| 501 | + |
| 502 | + compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn) |
485 | 503 |
|
486 | 504 |
|
487 | 505 | def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): |
|
0 commit comments