Skip to content

Commit 6eaec8f

Browse files
committed
Save more data to .json in preparation of other model kinds
1 parent 7edaab4 commit 6eaec8f

File tree

8 files changed

+81
-56
lines changed

8 files changed

+81
-56
lines changed

src/coreclr/scripts/cse_ml/evaluate.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas
1010
import tqdm
1111

12-
from jitml import SuperPmi, SuperPmiContext, JitCseModel, MethodContext, JitCseEnv
12+
from jitml import SuperPmi, SuperPmiContext, JitCseModel, MethodContext, JitCseEnv, split_for_cse
1313
from train import validate_core_root
1414

1515
class ModelResult(Enum):
@@ -109,9 +109,6 @@ def test_model(superpmi : SuperPmi, jitrl : JitCseModel, method_ids, model_name)
109109

110110
def evaluate(superpmi, jitrl, methods, model_name, csv_file) -> pandas.DataFrame:
111111
"""Evaluate the model and save to the specified CSV file."""
112-
print(csv_file)
113-
print(model_name)
114-
print(len(methods))
115112
if os.path.exists(csv_file):
116113
return pandas.read_csv(csv_file)
117114

@@ -200,10 +197,11 @@ def main(args):
200197
spmi_context = SuperPmiContext.load(spmi_file)
201198
else:
202199
print(f"Creating SuperPmiContext '{spmi_file}', this may take several minutes...")
203-
spmi_context = SuperPmiContext(core_root=args.core_root, mch=args.mch)
204-
spmi_context.find_methods_and_split(0.1)
200+
spmi_context = SuperPmiContext.create_from_mch(args.mch, args.core_root)
205201
spmi_context.save(spmi_file)
206202

203+
test_methods, training_methods = split_for_cse(spmi_context.methods, 0.1)
204+
207205
for file in enumerate_models(dir_or_path):
208206
print(file)
209207
with spmi_context.create_superpmi() as superpmi:
@@ -216,11 +214,11 @@ def main(args):
216214
model_name = os.path.splitext(file)[0]
217215

218216
filename = os.path.join(dir_or_path, f"{model_name}_test.csv")
219-
result = evaluate(superpmi, jitrl, spmi_context.test_methods, model_name, filename)
217+
result = evaluate(superpmi, jitrl, test_methods, model_name, filename)
220218
print_result(result, model_name, "Test")
221219

222220
filename = os.path.join(dir_or_path, f"{model_name}_train.csv")
223-
result = evaluate(superpmi, jitrl, spmi_context.training_methods, model_name, filename)
221+
result = evaluate(superpmi, jitrl, training_methods, model_name, filename)
224222
print_result(result, model_name, "Train")
225223

226224
if __name__ == "__main__":

src/coreclr/scripts/cse_ml/jitml/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .jit_cse import JitCseEnv
55
from .machine_learning import JitCseModel
66
from .wrappers import OptimalCseWrapper, NormalizeFeaturesWrapper
7+
from .constants import is_acceptable_for_cse, split_for_cse
78

89
__all__ = [
910
SuperPmi.__name__,
@@ -15,4 +16,6 @@
1516
JitType.__name__,
1617
OptimalCseWrapper.__name__,
1718
NormalizeFeaturesWrapper.__name__,
19+
is_acceptable_for_cse.__name__,
20+
split_for_cse.__name__,
1821
]
Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,50 @@
11
"""Constants and parameters for the project."""
22

3+
from typing import Sequence
4+
5+
import numpy as np
6+
from .method_context import MethodContext
7+
38
MIN_CSE = 3
49
MAX_CSE = 16
510

611
INVALID_ACTION_PENALTY = -0.05
712
INVALID_ACTION_LIMIT = 20
813

9-
def is_acceptable_method(method):
10-
"""Returns True if the method is acceptable for training."""
14+
def is_acceptable_for_cse(method):
15+
"""Returns True if the method is acceptable for training on JitCseEnv."""
1116
applicable = len([x for x in method.cse_candidates if x.viable])
1217
return MIN_CSE <= applicable and len(method.cse_candidates) <= MAX_CSE
18+
19+
def split_for_cse(methods : Sequence['MethodContext'], test_percent=0.1):
20+
"""Splits the methods into those that can be used for training and those that can't.
21+
Returns the test and train sets."""
22+
method_by_cse = {}
23+
24+
for x in methods:
25+
if is_acceptable_for_cse(x):
26+
method_by_cse.setdefault(x.num_cse, []).append(x)
27+
28+
# convert method_by_cse to a list of methods
29+
methods_list = []
30+
for value in method_by_cse.values():
31+
methods_list.append(value)
32+
33+
test = []
34+
train = []
35+
36+
# use a fixed seed so subsequent calls line up
37+
# Sort the groups of methods by length to ensure we don't care what order we process them in.
38+
# Then sort each method by id before shuffling to (again) ensure we get the same result.
39+
methods_list.sort(key=len)
40+
for method_group in methods_list:
41+
split = int(len(method_group) * test_percent)
42+
43+
# Discard any groups that are too small to split.
44+
if split > 0:
45+
method_group.sort(key=lambda x: x.index)
46+
np.random.default_rng(seed=42).shuffle(method_group)
47+
test.extend(method_group[:split])
48+
train.extend(method_group[split:])
49+
50+
return test, train

src/coreclr/scripts/cse_ml/jitml/jit_cse.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .method_context import JitType, MethodContext
88
from .superpmi import SuperPmi, SuperPmiContext
9-
from .constants import (INVALID_ACTION_PENALTY, INVALID_ACTION_LIMIT, MAX_CSE, is_acceptable_method)
9+
from .constants import (INVALID_ACTION_PENALTY, INVALID_ACTION_LIMIT, MAX_CSE, is_acceptable_for_cse)
1010

1111
# observation space
1212
JITTYPE_ONEHOT_SIZE = 6
@@ -60,13 +60,8 @@ def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = Non
6060
while True:
6161
index = self.__select_method()
6262
no_cse = self._jit_method_with_cleanup(index, JitMetrics=1, JitRLHook=1, JitRLHookCSEDecisions=[])
63-
if no_cse is None:
64-
continue
65-
66-
if is_acceptable_method(no_cse):
67-
original_heuristic = self._jit_method_with_cleanup(index, JitMetrics=1)
68-
if original_heuristic is None:
69-
continue
63+
original_heuristic = self._jit_method_with_cleanup(index, JitMetrics=1)
64+
if no_cse and original_heuristic:
7065
break
7166

7267
failure_count += 1
@@ -184,7 +179,7 @@ def get_observation(cls, method : MethodContext, fill=True):
184179

185180
# one-hot encode the type
186181
one_hot = [0.0] * 6
187-
one_hot[cse.type.value - 1] = 1.0
182+
one_hot[cse.type - 1] = 1.0
188183
tensor.extend(one_hot)
189184

190185
# boolean features
@@ -227,7 +222,7 @@ def _jit_method_with_cleanup(self, m_id, *args, **kwargs):
227222
def __select_method(self):
228223
if self.methods is None:
229224
superpmi = self.__get_or_create_superpmi()
230-
self.methods = [x.index for x in superpmi.enumerate_methods() if is_acceptable_method(x)]
225+
self.methods = [x.index for x in superpmi.enumerate_methods() if is_acceptable_for_cse(x)]
231226

232227
return np.random.choice(self.methods)
233228

src/coreclr/scripts/cse_ml/jitml/machine_learning.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from stable_baselines3.common.vec_env import SubprocVecEnv
1818
import gymnasium as gym
1919

20+
from .method_context import MethodContext
2021
from .jit_cse import JitCseEnv
2122
from .superpmi import SuperPmiContext
2223

@@ -60,12 +61,14 @@ def action_probabilities(self, obs):
6061
probs = action_distribution.distribution.probs
6162
return probs.cpu().detach().numpy()[0]
6263

63-
def train(self, pmi_context : SuperPmiContext, output_dir : str, iterations = None, parallel = None,
64-
progress_bar = True, wrappers : Optional[List[gym.Wrapper]] = None) -> str:
64+
def train(self, pmi_context : SuperPmiContext, training_methods : List[MethodContext], output_dir : str,
65+
iterations = None, parallel = None, progress_bar = True,
66+
wrappers : Optional[List[gym.Wrapper]] = None) -> str:
6567
"""Trains a model from scratch.
6668
6769
Args:
6870
pmi_context: The SuperPmiContext to use for training.
71+
training_methods : The methods to train on.
6972
output_dir: The directory to save the model to.
7073
iterations: The number of iterations to train for. Defaults to 100,000.
7174
parallel: The number of parallel environments to use. Defaults to single-process (None).
@@ -74,10 +77,11 @@ def train(self, pmi_context : SuperPmiContext, output_dir : str, iterations = No
7477
Returns:
7578
The full path to the trained model.
7679
"""
80+
training_methods = [m.index for m in training_methods]
7781
os.makedirs(output_dir, exist_ok=True)
7882

7983
def default_make_env():
80-
env = JitCseEnv(pmi_context, pmi_context.training_methods)
84+
env = JitCseEnv(pmi_context, training_methods)
8185
if wrappers:
8286
for wrapper in wrappers:
8387
env = wrapper(env)

src/coreclr/scripts/cse_ml/jitml/method_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class CseCandidate(BaseModel):
2626
make_cse : bool
2727
has_call : bool
2828
containable : bool
29-
type : JitType
29+
type : int
3030
cost_ex : int
3131
cost_sz : int
3232
use_count : int
@@ -38,6 +38,7 @@ class CseCandidate(BaseModel):
3838
bb_count : int
3939
block_spread : int
4040
enreg_count : int
41+
for_testing : Optional[bool] = False
4142

4243
@field_validator('applied', 'viable', 'live_across_call', 'const', 'shared_const', 'make_cse', 'has_call',
4344
'containable', mode='before')

src/coreclr/scripts/cse_ml/jitml/superpmi.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
import subprocess
66
import re
77
from typing import Iterable, List, Optional
8-
import numpy as np
98
from pydantic import BaseModel, field_validator
109

11-
from .constants import is_acceptable_method
1210
from .method_context import MethodContext
1311

1412
class SuperPmiContext(BaseModel):
@@ -18,8 +16,7 @@ class SuperPmiContext(BaseModel):
1816
core_root : str
1917
mch : str
2018
jit : Optional[str] = None
21-
test_methods : Optional[List[int]] = []
22-
training_methods : Optional[List[int]] = []
19+
methods : Optional[List[MethodContext]] = []
2320

2421
@field_validator('core_root', 'mch', mode='before')
2522
@classmethod
@@ -37,37 +34,24 @@ def _validate_optional_path(cls, v):
3734

3835
return v
3936

40-
def resplit_data(self, test_percent:float):
41-
"""Splits the data into training and testing sets."""
42-
if not self.test_methods and not self.training_methods:
43-
raise ValueError("No methods to split. Try calling 'find_methods_and_split' first.")
44-
45-
all_methods = self.test_methods + self.training_methods
46-
np.random.shuffle(all_methods)
47-
self.test_methods = all_methods[:int(len(all_methods) * test_percent)]
48-
self.training_methods = all_methods[len(self.test_methods):]
49-
50-
def find_methods_and_split(self, test_percent:float) -> None:
37+
@staticmethod
38+
def create_from_mch(mch : str, core_root : str, jit : Optional[str] = None) -> 'SuperPmiContext':
5139
"""Loads the SuperPmiContext from the specified arguments."""
52-
suitable_methods = []
53-
with SuperPmi(self) as superpmi:
40+
result = SuperPmiContext(core_root=core_root, mch=mch, jit=jit)
41+
42+
methods = []
43+
with SuperPmi(result) as superpmi:
5444
for method in superpmi.enumerate_methods():
55-
if is_acceptable_method(method):
56-
suitable_methods.append(method.index)
45+
methods.append(method)
5746

58-
self.test_methods = suitable_methods
59-
self.resplit_data(test_percent)
47+
result.methods = methods
48+
return result
6049

6150
def save(self, file_path:str):
6251
"""Saves the SuperPmiContext to a file."""
6352
with open(file_path, 'w', encoding="utf8") as f:
6453
json.dump(self.model_dump(), f)
6554

66-
67-
def create_superpmi(self, verbosity:str = 'q'):
68-
"""Creates a SuperPmi object from this context."""
69-
return SuperPmi(self, verbosity)
70-
7155
@staticmethod
7256
def load(file_path:str):
7357
"""Loads the SuperPmiContext from a file."""
@@ -78,13 +62,15 @@ def load(file_path:str):
7862
data = json.load(f)
7963
return SuperPmiContext(**data)
8064

65+
def create_superpmi(self, verbosity:str = 'q'):
66+
"""Creates a SuperPmi object from this context."""
67+
return SuperPmi(self, verbosity)
8168

8269
class SuperPmi:
8370
"""Controls one instance of superpmi."""
8471
def __init__(self, context : SuperPmiContext, verbosity:str = 'q'):
8572
"""Constructor.
8673
core_root is the path to the coreclr build, usually at [repo]/artifiacts/bin/coreclr/[arch]/.
87-
jit is the full path to the jit to use. Default is None.
8874
verbosity is the verbosity level of the superpmi process. Default is 'q'."""
8975
self._process = None
9076
self._feature_names = None

src/coreclr/scripts/cse_ml/train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import argparse
66

7-
from jitml import SuperPmiContext, JitCseModel, OptimalCseWrapper, NormalizeFeaturesWrapper
7+
from jitml import SuperPmiContext, JitCseModel, OptimalCseWrapper, NormalizeFeaturesWrapper, split_for_cse
88

99
def validate_core_root(core_root):
1010
"""Validates and returns the core_root directory."""
@@ -44,11 +44,11 @@ def main(args):
4444
ctx = SuperPmiContext.load(spmi_file)
4545
else:
4646
print(f"Creating SuperPmiContext '{spmi_file}', this may take several minutes...")
47-
ctx = SuperPmiContext(core_root=args.core_root, mch=args.mch)
48-
ctx.find_methods_and_split(args.test_percent)
47+
ctx = SuperPmiContext.create_from_mch(args.mch, args.core_root)
4948
ctx.save(spmi_file)
5049

51-
print(f"Training with {len(ctx.training_methods)} methods, holding back {len(ctx.test_methods)} for testing.")
50+
test_methods, training_methods = split_for_cse(ctx.methods, 0.1)
51+
print(f"Training with {len(training_methods)} methods, holding back {len(test_methods)} for testing.")
5252

5353
# Define our own environment (with wrappers) if requested.
5454

@@ -63,7 +63,7 @@ def main(args):
6363
wrappers.append(NormalizeFeaturesWrapper)
6464

6565
iterations = args.iterations if args.iterations is not None else 1_000_000
66-
path = rl.train(ctx, output_dir, iterations=iterations, parallel=args.parallel, wrappers=wrappers)
66+
path = rl.train(ctx, training_methods, output_dir, iterations=iterations, parallel=args.parallel, wrappers=wrappers)
6767
print(f"Model saved to: {path}")
6868

6969
if __name__ == "__main__":

0 commit comments

Comments
 (0)