11"""Ethereum benchmark test spec definition and filler."""
22
3- from typing import Callable , ClassVar , Dict , Generator , List , Optional , Sequence , Type
3+ from contextlib import contextmanager
4+ from contextvars import ContextVar
5+ from enum import Enum
6+ from typing import Any , Callable , ClassVar , Dict , Generator , List , Optional , Sequence , Type
47
58import pytest
6- from pydantic import Field
9+ from pydantic import ConfigDict , Field
710
811from ethereum_clis import TransitionTool
912from ethereum_test_base_types import HexNumber
2932from .blockchain import Block , BlockchainTest
3033
3134
35+ class BenchmarkPhase (Enum ):
36+ """Phases of a benchmark test."""
37+
38+ SETUP = "setup"
39+ EXECUTION = "execution"
40+
41+
42+ _current_phase : ContextVar [Optional [BenchmarkPhase ]] = ContextVar ("benchmark_phase" , default = None )
43+
44+
45+ class BenchmarkManager :
46+ """Context manager for managing benchmark test phases."""
47+
48+ def __init__ (self ):
49+ """Initialize the BenchmarkManager with empty transaction and block lists."""
50+ self .setup_transactions : List [Transaction ] = []
51+ self .setup_blocks : List [Block ] = []
52+ self .execution_transactions : List [Transaction ] = []
53+ self .execution_blocks : List [Block ] = []
54+
55+ @contextmanager
56+ def setup (self ):
57+ """Context manager for the setup phase of a benchmark test."""
58+ token = _current_phase .set (BenchmarkPhase .SETUP )
59+ try :
60+ yield self
61+ finally :
62+ _current_phase .reset (token )
63+
64+ @contextmanager
65+ def execution (self ):
66+ """Context manager for the execution phase of a benchmark test."""
67+ token = _current_phase .set (BenchmarkPhase .EXECUTION )
68+ try :
69+ yield self
70+ finally :
71+ _current_phase .reset (token )
72+
73+ def add_transaction (self , tx : Transaction ):
74+ """Add a transaction to the current phase."""
75+ current_phase = _current_phase .get ()
76+ if current_phase == BenchmarkPhase .SETUP :
77+ self .setup_transactions .append (tx )
78+ elif current_phase == BenchmarkPhase .EXECUTION :
79+ self .execution_transactions .append (tx )
80+ else :
81+ self .setup_transactions .append (tx )
82+
83+ def add_block (self , block : Block ):
84+ """Add a block to the current phase."""
85+ current_phase = _current_phase .get ()
86+ if current_phase == BenchmarkPhase .SETUP :
87+ self .setup_blocks .append (block )
88+ elif current_phase == BenchmarkPhase .EXECUTION :
89+ self .execution_blocks .append (block )
90+ else :
91+ self .setup_blocks .append (block )
92+
93+ def get_current_phase (self ) -> Optional [BenchmarkPhase ]:
94+ """Get the current benchmark phase."""
95+ return _current_phase .get ()
96+
97+
3298class BenchmarkTest (BaseTest ):
3399 """Test type designed specifically for benchmark test cases."""
34100
101+ model_config = ConfigDict (extra = "forbid" )
102+
35103 pre : Alloc
36104 post : Alloc
37105 tx : Optional [Transaction ] = None
@@ -41,6 +109,9 @@ class BenchmarkTest(BaseTest):
41109 ) = None
42110 env : Environment = Field (default_factory = Environment )
43111 expected_benchmark_gas_used : int | None = None
112+ gas_benchmark_value : int
113+ benchmark_manager : Optional [Any ] = Field (default = None , exclude = True )
114+ code_generator : Optional [Any ] = Field (default = None , exclude = True )
44115
45116 supported_fixture_formats : ClassVar [Sequence [FixtureFormat | LabeledFixtureFormat ]] = [
46117 BlockchainFixture ,
@@ -86,26 +157,81 @@ def get_genesis_environment(self, fork: Fork) -> Environment:
86157
87158 def split_transaction (self , tx : Transaction , gas_limit_cap : int | None ) -> List [Transaction ]:
88159 """Split a transaction that exceeds the gas limit cap into multiple transactions."""
89- if (gas_limit_cap is None ) or (tx .gas_limit <= gas_limit_cap ):
160+ if gas_limit_cap is None :
161+ tx .gas_limit = HexNumber (self .gas_benchmark_value )
162+ return [tx ]
163+
164+ if gas_limit_cap >= self .gas_benchmark_value :
165+ tx .gas_limit = HexNumber (min (tx .gas_limit , self .gas_benchmark_value ))
90166 return [tx ]
91167
92- total_gas = int (self .expected_benchmark_gas_used or self .env .gas_limit )
93- print (f"total_gas: { total_gas } " )
94- num_splits = total_gas // gas_limit_cap
168+ remaining_gas = self .gas_benchmark_value
169+ num_splits = remaining_gas // gas_limit_cap + int (remaining_gas % gas_limit_cap )
95170
96171 split_transactions = []
97172 for i in range (num_splits ):
98173 split_tx = tx .model_copy ()
99- total_gas -= gas_limit_cap
100- split_tx . gas_limit = HexNumber ( total_gas if i == num_splits - 1 else gas_limit_cap )
174+ split_tx . gas_limit = HexNumber ( remaining_gas if i == num_splits - 1 else gas_limit_cap )
175+ remaining_gas -= gas_limit_cap
101176 split_tx .nonce = HexNumber (tx .nonce + i )
102177 split_transactions .append (split_tx )
103178
104179 return split_transactions
105180
181+ def generate_blocks_from_code_generator (self , fork : Fork ) -> List [Block ]:
182+ """Generate blocks using the code generator."""
183+ if self .code_generator is None :
184+ return []
185+
186+ self .code_generator .deploy_contracts (self .pre )
187+ gas_limit = fork .transaction_gas_limit_cap () or self .gas_benchmark_value
188+ benchmark_tx = self .code_generator .generate_transaction (self .pre , gas_limit )
189+
190+ execution_txs = self .split_transaction (benchmark_tx , gas_limit )
191+ execution_block = Block (txs = execution_txs )
192+
193+ return [execution_block ]
194+
106195 def generate_blockchain_test (self , fork : Fork ) -> BlockchainTest :
107196 """Create a BlockchainTest from this BenchmarkTest."""
108- if self .blocks is not None :
197+ if self .code_generator is not None :
198+ generated_blocks = self .generate_blocks_from_code_generator (fork )
199+ return BlockchainTest .from_test (
200+ base_test = self ,
201+ genesis_environment = self .env ,
202+ pre = self .pre ,
203+ post = self .post ,
204+ blocks = generated_blocks ,
205+ )
206+
207+ elif self .benchmark_manager is not None :
208+ all_blocks = []
209+ gas_limit = fork .transaction_gas_limit_cap () or self .gas_benchmark_value
210+
211+ if self .benchmark_manager .setup_blocks :
212+ all_blocks .extend (self .benchmark_manager .setup_blocks )
213+ elif self .benchmark_manager .setup_transactions :
214+ setup_txs = []
215+ for tx in self .benchmark_manager .setup_transactions :
216+ setup_txs .extend (self .split_transaction (tx , gas_limit ))
217+ all_blocks .append (Block (txs = setup_txs ))
218+
219+ if self .benchmark_manager .execution_blocks :
220+ all_blocks .extend (self .benchmark_manager .execution_blocks )
221+ elif self .benchmark_manager .execution_transactions :
222+ execution_txs = []
223+ for tx in self .benchmark_manager .execution_transactions :
224+ execution_txs .extend (self .split_transaction (tx , gas_limit ))
225+ all_blocks .append (Block (txs = execution_txs ))
226+
227+ return BlockchainTest .from_test (
228+ base_test = self ,
229+ genesis_environment = self .env ,
230+ pre = self .pre ,
231+ post = self .post ,
232+ blocks = all_blocks ,
233+ )
234+ elif self .blocks is not None :
109235 return BlockchainTest .from_test (
110236 base_test = self ,
111237 genesis_environment = self .env ,
@@ -114,9 +240,9 @@ def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
114240 blocks = self .blocks ,
115241 )
116242 elif self .tx is not None :
117- gas_limit_cap = fork .transaction_gas_limit_cap ()
243+ gas_limit = fork .transaction_gas_limit_cap () or self . gas_benchmark_value
118244
119- transactions = self .split_transaction (self .tx , gas_limit_cap )
245+ transactions = self .split_transaction (self .tx , gas_limit )
120246
121247 blocks = [Block (txs = transactions )]
122248
@@ -129,7 +255,7 @@ def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
129255 )
130256 else :
131257 raise ValueError (
132- "Cannot create BlockchainTest without transactions, blocks, or code_generator "
258+ "Cannot create BlockchainTest without transactions, blocks, or benchmark_manager "
133259 )
134260
135261 def generate (
@@ -162,5 +288,10 @@ def execute(
162288 raise Exception (f"Unsupported execute format: { execute_format } " )
163289
164290
291+ def create_benchmark_manager () -> BenchmarkManager :
292+ """Create a new BenchmarkManager instance for phase-aware benchmark testing."""
293+ return BenchmarkManager ()
294+
295+
165296BenchmarkTestSpec = Callable [[str ], Generator [BenchmarkTest , None , None ]]
166297BenchmarkTestFiller = Type [BenchmarkTest ]
0 commit comments