diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index c40d7655f5f32b..963acd9a3e82ff 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -2541,6 +2541,7 @@ class Compiler friend class CSE_HeuristicReplay; friend class CSE_HeuristicRL; friend class CSE_HeuristicParameterized; + friend class CSE_HeuristicRLHook; friend class CSE_Heuristic; friend class CodeGenInterface; friend class CodeGen; diff --git a/src/coreclr/jit/jitconfigvalues.h b/src/coreclr/jit/jitconfigvalues.h index d2ea1deca5166a..aa645c5b4d7411 100644 --- a/src/coreclr/jit/jitconfigvalues.h +++ b/src/coreclr/jit/jitconfigvalues.h @@ -498,6 +498,15 @@ CONFIG_STRING(JitRLCSEAlpha, W("JitRLCSEAlpha")) // If nonzero, dump candidate feature values CONFIG_INTEGER(JitRLCSECandidateFeatures, W("JitRLCSECandidateFeatures"), 0) +// Enable CSE_HeuristicRLHook +CONFIG_INTEGER(JitRLHook, W("JitRLHook"), 0) // If 1, emit RL callbacks + +// If 1, emit feature column names +CONFIG_INTEGER(JitRLHookEmitFeatureNames, W("JitRLHookEmitFeatureNames"), 0) + +// A list of CSEs to choose, in the order they should be applied. +CONFIG_STRING(JitRLHookCSEDecisions, W("JitRLHookCSEDecisions")) + #if !defined(DEBUG) && !defined(_DEBUG) RELEASE_CONFIG_INTEGER(JitEnableNoWayAssert, W("JitEnableNoWayAssert"), 0) #else // defined(DEBUG) || defined(_DEBUG) diff --git a/src/coreclr/jit/optcse.cpp b/src/coreclr/jit/optcse.cpp index 55533e306d8f40..b3d8e2610168a1 100644 --- a/src/coreclr/jit/optcse.cpp +++ b/src/coreclr/jit/optcse.cpp @@ -2973,6 +2973,298 @@ void CSE_HeuristicParameterized::DumpChoices(ArrayStack& choices, CSEdsc #ifdef DEBUG +//------------------------------------------------------------------------ +// CSE_HeuristicRLHook: a generic 'hook' for driving CSE decisions out of +// process using reinforcement learning +// +// Arguments; +// pCompiler - compiler instance +// +// Notes: +// This creates a hook to control CSE decisions from an external process +// when JitRLHook=1 is set. This will cause the JIT to emit a series of +// feature building blocks for each CSE in the method. Feature names for +// these values can be found by setting JitRLHookEmitFeatureNames=1. To +// control the CSE decisions, set JitRLHookCSEDecisions with a sequence +// of CSE indices to apply. +// +// This hook is only available in debug/checked builds, and does not +// contain any machine learning code. +// +CSE_HeuristicRLHook::CSE_HeuristicRLHook(Compiler* pCompiler) + : CSE_HeuristicCommon(pCompiler) +{ +} + +//------------------------------------------------------------------------ +// ConsiderTree: check if this tree can be a CSE candidate +// +// Arguments: +// tree - tree in question +// isReturn - true if tree is part of a return statement +// +// Returns: +// true if this tree can be a CSE +bool CSE_HeuristicRLHook::ConsiderTree(GenTree* tree, bool isReturn) +{ + return CanConsiderTree(tree, isReturn); +} + +//------------------------------------------------------------------------ +// ConsiderCandidates: examine candidates and perform CSEs. +// This simply defers to the JitRLHookCSEDecisions config value. +// +void CSE_HeuristicRLHook::ConsiderCandidates() +{ + if (JitConfig.JitRLHookCSEDecisions() != nullptr) + { + ConfigIntArray JitRLHookCSEDecisions; + JitRLHookCSEDecisions.EnsureInit(JitConfig.JitRLHookCSEDecisions()); + + unsigned cnt = m_pCompiler->optCSECandidateCount; + for (unsigned i = 0; i < JitRLHookCSEDecisions.GetLength(); i++) + { + const int index = JitRLHookCSEDecisions.GetData()[i]; + if ((index < 0) || (index >= (int)cnt)) + { + JITDUMP("Invalid candidate number %d\n", index + 1); + continue; + } + + CSEdsc* const dsc = m_pCompiler->optCSEtab[index]; + if (!dsc->IsViable()) + { + JITDUMP("Abandoned " FMT_CSE " -- not viable\n", dsc->csdIndex); + continue; + } + + const int attempt = m_pCompiler->optCSEattempt++; + CSE_Candidate candidate(this, dsc); + + JITDUMP("\nRLHook attempting " FMT_CSE "\n", candidate.CseIndex()); + JITDUMP("CSE Expression : \n"); + JITDUMPEXEC(m_pCompiler->gtDispTree(candidate.Expr())); + JITDUMP("\n"); + + PerformCSE(&candidate); + madeChanges = true; + } + } +} + +//------------------------------------------------------------------------ +// DumpMetrics: write out features for each CSE candidate +// Format: +// featureNames +// features #, +// seq +// +// Notes: +// featureNames are emitted only if JitRLHookEmitFeatureNames is set. +// features are 0 indexed, and the index is the first value, following #. +// seq is a comma separated list of CSE indices that were applied, or +// omitted if none were selected +// +void CSE_HeuristicRLHook::DumpMetrics() +{ + // Feature names, if requested + if (JitConfig.JitRLHookEmitFeatureNames() > 0) + { + printf(" featureNames "); + for (int i = 0; i < maxFeatures; i++) + { + printf("%s%s", (i == 0) ? "" : ",", s_featureNameAndType[i]); + } + } + + // features + for (unsigned i = 0; i < m_pCompiler->optCSECandidateCount; i++) + { + CSEdsc* const cse = m_pCompiler->optCSEtab[i]; + + int features[maxFeatures]; + GetFeatures(cse, features); + + printf(" features #%i", cse->csdIndex); + for (int j = 0; j < maxFeatures; j++) + { + printf(",%d", features[j]); + } + } + + // The selected sequence of CSEs that were applied + if (JitConfig.JitRLHookCSEDecisions() != nullptr) + { + ConfigIntArray JitRLHookCSEDecisions; + JitRLHookCSEDecisions.EnsureInit(JitConfig.JitRLHookCSEDecisions()); + + if (JitRLHookCSEDecisions.GetLength() > 0) + { + printf(" seq "); + for (unsigned i = 0; i < JitRLHookCSEDecisions.GetLength(); i++) + { + printf("%s%d", (i == 0) ? "" : ",", JitRLHookCSEDecisions.GetData()[i]); + } + } + } +} + +//------------------------------------------------------------------------ +// GetFeatures: extract features for this CSE +// Arguments: +// cse - cse descriptor +// features - array to fill in with feature values, this must be of length +// maxFeatures or greater +// +// Notes: +// Features are intended to be building blocks of "real" features that +// are further defined and refined in the machine learning model. That +// means that each "feature" here is a simple value and not a composite +// of multiple values. +// +// Features do not need to be stable across builds, they can be changed, +// added, or removed. However, the corresponding code needs to be updated +// to match: src/coreclr/scripts/cse_ml/jitml/method_context.py +// See src/coreclr/scripts/cse_ml/README.md for more information. +// +void CSE_HeuristicRLHook::GetFeatures(CSEdsc* cse, int* features) +{ + assert(cse != nullptr); + assert(features != nullptr); + CSE_Candidate candidate(this, cse); + + int enregCount = 0; + for (unsigned trackedIndex = 0; trackedIndex < m_pCompiler->lvaTrackedCount; trackedIndex++) + { + LclVarDsc* varDsc = m_pCompiler->lvaGetDescByTrackedIndex(trackedIndex); + var_types varTyp = varDsc->TypeGet(); + + // Locals with no references aren't enregistered + if (varDsc->lvRefCnt() == 0) + { + continue; + } + + // Some LclVars always have stack homes + if (varDsc->lvDoNotEnregister) + { + continue; + } + + if (!varTypeIsFloating(varTyp)) + { + enregCount++; // The primitive types, including TYP_SIMD types use one register + +#ifndef TARGET_64BIT + if (varTyp == TYP_LONG) + { + enregCount++; // on 32-bit targets longs use two registers + } +#endif + } + } + + const unsigned numBBs = m_pCompiler->fgBBcount; + bool isMakeCse = false; + unsigned minPostorderNum = numBBs; + unsigned maxPostorderNum = 0; + BasicBlock* minPostorderBlock = nullptr; + BasicBlock* maxPostorderBlock = nullptr; + for (treeStmtLst* treeList = cse->csdTreeList; treeList != nullptr; treeList = treeList->tslNext) + { + BasicBlock* const treeBlock = treeList->tslBlock; + unsigned postorderNum = treeBlock->bbPostorderNum; + if (postorderNum < minPostorderNum) + { + minPostorderNum = postorderNum; + minPostorderBlock = treeBlock; + } + + if (postorderNum > maxPostorderNum) + { + maxPostorderNum = postorderNum; + maxPostorderBlock = treeBlock; + } + + isMakeCse |= ((treeList->tslTree->gtFlags & GTF_MAKE_CSE) != 0); + } + + const unsigned blockSpread = maxPostorderNum - minPostorderNum; + + int type = rlHookTypeOther; + if (candidate.Expr()->TypeIs(TYP_INT)) + { + type = rlHookTypeInt; + } + else if (candidate.Expr()->TypeIs(TYP_LONG)) + { + type = rlHookTypeLong; + } + else if (candidate.Expr()->TypeIs(TYP_FLOAT)) + { + type = rlHookTypeFloat; + } + else if (candidate.Expr()->TypeIs(TYP_DOUBLE)) + { + type = rlHookTypeDouble; + } + else if (candidate.Expr()->TypeIs(TYP_STRUCT)) + { + type = rlHookTypeStruct; + } + +#ifdef FEATURE_SIMD + else if (varTypeIsSIMD(candidate.Expr()->TypeGet())) + { + type = rlHookTypeSimd; + } +#ifdef TARGET_XARCH + else if (candidate.Expr()->TypeIs(TYP_SIMD32, TYP_SIMD64)) + { + type = rlHookTypeSimd; + } +#endif +#endif + + int i = 0; + features[i++] = type; + features[i++] = cse->IsViable() ? 1 : 0; + features[i++] = cse->csdLiveAcrossCall ? 1 : 0; + features[i++] = cse->csdTree->OperIsConst() ? 1 : 0; + features[i++] = cse->csdIsSharedConst ? 1 : 0; + features[i++] = isMakeCse ? 1 : 0; + features[i++] = ((cse->csdTree->gtFlags & GTF_CALL) != 0) ? 1 : 0; + features[i++] = cse->csdTree->OperIs(GT_ADD, GT_NOT, GT_MUL, GT_LSH) ? 1 : 0; + features[i++] = cse->csdTree->GetCostEx(); + features[i++] = cse->csdTree->GetCostSz(); + features[i++] = cse->csdUseCount; + features[i++] = cse->csdDefCount; + features[i++] = (int)cse->csdUseWtCnt; + features[i++] = (int)cse->csdDefWtCnt; + features[i++] = cse->numDistinctLocals; + features[i++] = cse->numLocalOccurrences; + features[i++] = numBBs; + features[i++] = blockSpread; + features[i++] = enregCount; + + assert(i <= maxFeatures); + + for (; i < maxFeatures; i++) + { + features[i] = 0; + } +} + +// These need to match the features above, and match the field name of MethodContext +// in src/coreclr/scripts/cse_ml/jitml/method_context.py +const char* const CSE_HeuristicRLHook::s_featureNameAndType[] = { + "type", "viable", "live_across_call", "const", + "shared_const", "make_cse", "has_call", "containable", + "cost_ex", "cost_sz", "use_count", "def_count", + "use_wt_cnt", "def_wt_cnt", "distinct_locals", "local_occurrences", + "bb_count", "block_spread", "enreg_count", +}; + //------------------------------------------------------------------------ // CSE_HeuristicRL: construct RL CSE heuristic // @@ -5165,9 +5457,20 @@ CSE_HeuristicCommon* Compiler::optGetCSEheuristic() // Enable optional policies // - // RL takes precedence + // RL hook takes precedence // if (optCSEheuristic == nullptr) + { + bool useRLHook = (JitConfig.JitRLHook() > 0); + + if (useRLHook) + { + optCSEheuristic = new (this, CMK_CSE) CSE_HeuristicRLHook(this); + } + } + + // then RL + if (optCSEheuristic == nullptr) { bool useRLHeuristic = (JitConfig.JitRLCSE() != nullptr); diff --git a/src/coreclr/jit/optcse.h b/src/coreclr/jit/optcse.h index b8dd9fae685df7..48fba4e50e13a8 100644 --- a/src/coreclr/jit/optcse.h +++ b/src/coreclr/jit/optcse.h @@ -217,6 +217,49 @@ class CSE_HeuristicParameterized : public CSE_HeuristicCommon #ifdef DEBUG +// General Reinforcement Learning CSE heuristic hook. +// +// Produces a wide set of data to train a RL model. +// Consumes the decisions made by a model to perform CSEs. +// +class CSE_HeuristicRLHook : public CSE_HeuristicCommon +{ +private: + static const char* const s_featureNameAndType[]; + + void GetFeatures(CSEdsc* cse, int* features); + + enum + { + maxFeatures = 19, + }; + + enum + { + rlHookTypeOther = 0, + rlHookTypeInt = 1, + rlHookTypeLong = 2, + rlHookTypeFloat = 3, + rlHookTypeDouble = 4, + rlHookTypeStruct = 5, + rlHookTypeSimd = 6, + }; + +public: + CSE_HeuristicRLHook(Compiler*); + void ConsiderCandidates(); + bool ConsiderTree(GenTree* tree, bool isReturn); + + const char* Name() const + { + return "RL Hook CSE Heuristic"; + } + +#ifdef DEBUG + virtual void DumpMetrics(); +#endif +}; + // Reinforcement Learning CSE heuristic // // Uses a "linear" feature model with diff --git a/src/coreclr/scripts/cse_ml/.gitignore b/src/coreclr/scripts/cse_ml/.gitignore new file mode 100644 index 00000000000000..0d41b55212ac2a --- /dev/null +++ b/src/coreclr/scripts/cse_ml/.gitignore @@ -0,0 +1,2 @@ +# The root .gitignore doens't mark this as ignored: +__pycache__ diff --git a/src/coreclr/scripts/cse_ml/.pylintrc b/src/coreclr/scripts/cse_ml/.pylintrc new file mode 100644 index 00000000000000..fbd1f62fa6b168 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/.pylintrc @@ -0,0 +1,5 @@ +[FORMAT] +max-line-length = 120 + +[DESIGN] +max-args=8 diff --git a/src/coreclr/scripts/cse_ml/evaluate.py b/src/coreclr/scripts/cse_ml/evaluate.py new file mode 100755 index 00000000000000..e14f48cce78e69 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/evaluate.py @@ -0,0 +1,225 @@ +#!/usr/bin/python + +"""Evaluates a model on a given dataset.""" +from enum import Enum +import os +import argparse +import shutil +import numpy as np +import pandas +import tqdm + +from jitml import SuperPmi, SuperPmiContext, JitCseModel, MethodContext, JitCseEnv, split_for_cse +from train import validate_core_root + +class ModelResult(Enum): + """Analysis errors.""" + OK = 0 + JIT_FAILED = 1 + +def set_result(data, m_id, heuristic_score, no_cse_score, model_score, error = ModelResult.OK): + """Sets the results for the given method id.""" + data["method_id"].append(m_id) + data["heuristic_score"].append(heuristic_score) + data["no_cse_score"].append(no_cse_score) + data["model_score"].append(model_score) + data["failed"].append(error) + +def get_most_likley_allowed_action(jitrl : JitCseModel, method : MethodContext, can_terminate : bool): + """Returns the most likely allowed actions.""" + obs = JitCseEnv.get_observation(method) + probabilities = jitrl.action_probabilities(obs) + + # If we are not allowed to terminate, remove the terminate action. + terminate_action = len(probabilities) - 1 + if not can_terminate: + probabilities = probabilities[:-1] + + # Sorted by most likely action descending + sorted_actions = np.flip(np.argsort(probabilities)) + candidates = method.cse_candidates + + for action in sorted_actions: + if action == terminate_action: + return None + + if action < len(candidates) and candidates[action].can_apply: + return action + + # We are supposed to terminate instead of applying a CSE if none are available. + # If we got here there's some kind of error. + raise ValueError("No valid action found.") + +def test_model(superpmi : SuperPmi, jitrl : JitCseModel, method_ids, model_name): + """Tests the model on the test set.""" + data = { + "method_id" : [], + "heuristic_score" : [], + "no_cse_score" : [], + "model_score" : [], + "failed" : [] + } + + for m_id in tqdm.tqdm(method_ids, + desc=f"Processing {model_name}", + colour='green', + ncols=shutil.get_terminal_size().columns - 8, + ascii=False): + # the original JIT method + original = superpmi.jit_method(m_id, JitMetrics=1) + no_cse = superpmi.jit_method(m_id, JitMetrics=1, JitRLHook=1, JitRLHookCSEDecisions=[]) + + if original is None or no_cse is None: + set_result(data, m_id, 0, 0, 0, ModelResult.JIT_FAILED) + continue + + choices = [] + results = [] + while True: + prev_method = results[-1] if results else no_cse + + # If we have no more CSEs to apply, we are done. We expect this not to happen on the first + # iteration because we filter out methods that have no CSEs to apply. + if not any(x.can_apply for x in prev_method.cse_candidates): + set_result(data, m_id, original.perf_score, no_cse.perf_score, prev_method.perf_score) + + assert choices # We must have made at least one selection + break + + action = get_most_likley_allowed_action(jitrl, prev_method, choices) + if action is None: + set_result(data, m_id, original.perf_score, no_cse.perf_score, prev_method.perf_score) + break + + # apply the CSE + choices.append(action) + new_method = superpmi.jit_method(m_id, JitMetrics=1, JitRLHook=1, JitRLHookCSEDecisions=choices) + if new_method is None: + set_result(data, m_id, original.perf_score, no_cse.perf_score, prev_method.perf_score, + ModelResult.JIT_FAILED) + break + + results.append(new_method) + + # mark choices as applied + for c in choices: + new_method.cse_candidates[c].applied = True + + return pandas.DataFrame(data) + +def evaluate(superpmi, jitrl, methods, model_name, csv_file) -> pandas.DataFrame: + """Evaluate the model and save to the specified CSV file.""" + if os.path.exists(csv_file): + return pandas.read_csv(csv_file) + + result = test_model(superpmi, jitrl, methods, model_name) + result.to_csv(csv_file) + return result + +def enumerate_models(dir_or_file): + """Enumerates the models in the specified directory.""" + if os.path.isfile(dir_or_file): + return [dir_or_file] + + def extract_number(file): + return int(file.split("_")[-1]) if file.split("_")[-1].isdigit() else 100000000 + + files = [os.path.splitext(file)[0] for file in os.listdir(dir_or_file) if file.endswith(".zip")] + return sorted(files, key=extract_number, reverse=True) + +def print_result(result, model, kind): + """Prints the results.""" + + print('-' * 40 + f" {model} results " + '-' * 40) + print() + + print(f"{kind} results:") + print() + + print("Comparisons:") + no_jit_failure = result[result['failed'] != ModelResult.JIT_FAILED] + + # next calculate how often we improved on the heuristic + improved = no_jit_failure[no_jit_failure['model_score'] < no_jit_failure['heuristic_score']] + underperformed = no_jit_failure[no_jit_failure['model_score'] > no_jit_failure['heuristic_score']] + print(f"Better than heuristic: {len(improved)}") + print(f"Worse than heuristic: {len(underperformed)}") + print(f"Same as heuristic: {len(no_jit_failure) - len(improved) - len(underperformed)}") + print(f"Total: {len(result)}") + print() + + # sum up total difference + h_score = no_jit_failure['heuristic_score'].sum() + heuristic_diff = no_jit_failure['model_score'].sum() - h_score + print(f"Total heuristic difference: {heuristic_diff} ({heuristic_diff / len(no_jit_failure)} per method)") + print(f"Pct improvement: {-heuristic_diff / h_score * 100:.2f}%") + nc_score = no_jit_failure['no_cse_score'].sum() + no_cse_diff = no_jit_failure['model_score'].sum() - nc_score + print(f"Total no CSE difference: {no_cse_diff} ({no_cse_diff / len(no_jit_failure)} per method)") + print(f"Pct improvement: {-no_cse_diff / nc_score * 100:.2f}%") + print() + + # next calculate how often we improved on the no CSE score + improved = no_jit_failure[no_jit_failure['model_score'] < no_jit_failure['no_cse_score']] + underperformed = no_jit_failure[no_jit_failure['model_score'] > no_jit_failure['no_cse_score']] + print(f"Better than no CSE: {len(improved)}") + print(f"Worse than no CSE: {len(underperformed)}") + print(f"Same as no CSE: {len(no_jit_failure) - len(improved) - len(underperformed)}") + print() + + print("Failures:") + print(f"Failed: {len(result[result['failed'] != ModelResult.OK])}") + print(f"JIT Failed: {len(result[result['failed'] == ModelResult.JIT_FAILED])}") + print() + +def parse_args(): + """usage: train.py [-h] [--core_root CORE_ROOT] [--parallel n] [--iterations i] model_path mch""" + parser = argparse.ArgumentParser() + parser.add_argument("model_path", help="The directory or model path to load from.") + parser.add_argument("mch", help="The mch file of functions to evaluate the model with.") + parser.add_argument("--core_root", default=None, help="The coreclr root directory.") + parser.add_argument("--algorithm", default="PPO", help="The algorithm to use. (default: PPO)") + + args = parser.parse_args() + args.core_root = validate_core_root(args.core_root) + return args + +def main(args): + """Main entry point.""" + dir_or_path = args.model_path + + if not os.path.exists(dir_or_path): + raise FileNotFoundError(f"Path {dir_or_path} does not exist.") + + # Load data. + spmi_file = args.mch + ".json" + if os.path.exists(spmi_file): + spmi_context = SuperPmiContext.load(spmi_file) + else: + print(f"Creating SuperPmiContext '{spmi_file}', this may take several minutes...") + spmi_context = SuperPmiContext.create_from_mch(args.mch, args.core_root) + spmi_context.save(spmi_file) + + test_methods, training_methods = split_for_cse(spmi_context.methods, 0.1) + + for file in enumerate_models(dir_or_path): + print(file) + with spmi_context.create_superpmi() as superpmi: + # load the underlying model + jitrl = JitCseModel(args.algorithm) + jitrl.load(os.path.join(dir_or_path, file)) + + print(f"Evaluting model {file} on training and test data:") + + model_name = os.path.splitext(file)[0] + + filename = os.path.join(dir_or_path, f"{model_name}_test.csv") + result = evaluate(superpmi, jitrl, test_methods, model_name, filename) + print_result(result, model_name, "Test") + + filename = os.path.join(dir_or_path, f"{model_name}_train.csv") + result = evaluate(superpmi, jitrl, training_methods, model_name, filename) + print_result(result, model_name, "Train") + +if __name__ == "__main__": + main(parse_args()) diff --git a/src/coreclr/scripts/cse_ml/img/training.png b/src/coreclr/scripts/cse_ml/img/training.png new file mode 100644 index 00000000000000..819953be1adeb1 Binary files /dev/null and b/src/coreclr/scripts/cse_ml/img/training.png differ diff --git a/src/coreclr/scripts/cse_ml/jitml/__init__.py b/src/coreclr/scripts/cse_ml/jitml/__init__.py new file mode 100644 index 00000000000000..17862ccadeae6c --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/__init__.py @@ -0,0 +1,21 @@ +"""JIT Machine Learning (JITML) is a Python library for the .Net JIT's reinforcement learning algorithms.""" +from .method_context import MethodContext, CseCandidate, JitType +from .superpmi import SuperPmi, SuperPmiContext +from .jit_cse import JitCseEnv +from .machine_learning import JitCseModel +from .wrappers import OptimalCseWrapper, NormalizeFeaturesWrapper +from .constants import is_acceptable_for_cse, split_for_cse + +__all__ = [ + SuperPmi.__name__, + SuperPmiContext.__name__, + JitCseEnv.__name__, + JitCseModel.__name__, + MethodContext.__name__, + CseCandidate.__name__, + JitType.__name__, + OptimalCseWrapper.__name__, + NormalizeFeaturesWrapper.__name__, + is_acceptable_for_cse.__name__, + split_for_cse.__name__, +] diff --git a/src/coreclr/scripts/cse_ml/jitml/constants.py b/src/coreclr/scripts/cse_ml/jitml/constants.py new file mode 100644 index 00000000000000..0cb0d92b7b8820 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/constants.py @@ -0,0 +1,50 @@ +"""Constants and parameters for the project.""" + +from typing import Sequence + +import numpy as np +from .method_context import MethodContext + +MIN_CSE = 3 +MAX_CSE = 16 + +INVALID_ACTION_PENALTY = -0.05 +INVALID_ACTION_LIMIT = 20 + +def is_acceptable_for_cse(method): + """Returns True if the method is acceptable for training on JitCseEnv.""" + applicable = len([x for x in method.cse_candidates if x.viable]) + return MIN_CSE <= applicable and len(method.cse_candidates) <= MAX_CSE + +def split_for_cse(methods : Sequence['MethodContext'], test_percent=0.1): + """Splits the methods into those that can be used for training and those that can't. + Returns the test and train sets.""" + method_by_cse = {} + + for x in methods: + if is_acceptable_for_cse(x): + method_by_cse.setdefault(x.num_cse, []).append(x) + + # convert method_by_cse to a list of methods + methods_list = [] + for value in method_by_cse.values(): + methods_list.append(value) + + test = [] + train = [] + + # use a fixed seed so subsequent calls line up + # Sort the groups of methods by length to ensure we don't care what order we process them in. + # Then sort each method by id before shuffling to (again) ensure we get the same result. + methods_list.sort(key=len) + for method_group in methods_list: + split = int(len(method_group) * test_percent) + + # Discard any groups that are too small to split. + if split > 0: + method_group.sort(key=lambda x: x.index) + np.random.default_rng(seed=42).shuffle(method_group) + test.extend(method_group[:split]) + train.extend(method_group[split:]) + + return test, train diff --git a/src/coreclr/scripts/cse_ml/jitml/jit_cse.py b/src/coreclr/scripts/cse_ml/jitml/jit_cse.py new file mode 100644 index 00000000000000..aa9191e8faef6e --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/jit_cse.py @@ -0,0 +1,249 @@ +"""A gymnasium environment for training RL to optimize the .Net JIT's CSE usage.""" + +from typing import Any, Dict, List, Optional +import gymnasium as gym +import numpy as np + +from .method_context import JitType, MethodContext +from .superpmi import SuperPmi, SuperPmiContext +from .constants import (INVALID_ACTION_PENALTY, INVALID_ACTION_LIMIT, MAX_CSE, is_acceptable_for_cse) + +# observation space +JITTYPE_ONEHOT_SIZE = 6 +BOOLEAN_FEATURES = 7 +FLOAT_FEATURES = 9 +FEATURES = JITTYPE_ONEHOT_SIZE + BOOLEAN_FEATURES + FLOAT_FEATURES + +# Scale up the reward to make it more meaningful. +REWARD_SCALE = 5.0 + +class JitCseEnv(gym.Env): + """A gymnasium environment for CSE optimization selection in the JIT.""" + observation_columns : List[str] = [f"type_{JitType(i).name.lower()}" for i in range(1, 7)] + \ + [ + "can_apply", "live_across_call", "const", "shared_const", "make_cse", "has_call", "containable", + "cost_ex", "cost_sz", "use_count", "def_count", "use_wt_cnt", "def_wt_cnt", "distinct_locals", + "local_occurrences", "enreg_count" + ] + + def __init__(self, context : SuperPmiContext, methods : Optional[List[int]] = None, **kwargs): + super().__init__(**kwargs) + + self.pmi_context = context + self.methods = methods or context.training_methods + if not self.methods: + raise ValueError("No methods to train on.") + + self.__superpmi : SuperPmi = None + self.action_space = gym.spaces.Discrete(MAX_CSE + 1) + self.observation_space = gym.spaces.Box(np.zeros((MAX_CSE, FEATURES)), + np.ones((MAX_CSE, FEATURES)), + dtype=np.float32) + + self.last_info : Optional[Dict[str,object]] = None + + def __del__(self): + self.close() + + def close(self): + """Closes the environment and cleans up resources.""" + super().close() + if self.__superpmi is not None: + self.__superpmi.stop() + self.__superpmi = None + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + super().reset(seed=seed, options=options) + self.last_info = None + + failure_count = 0 + while True: + index = self.__select_method() + no_cse = self._jit_method_with_cleanup(index, JitMetrics=1, JitRLHook=1, JitRLHookCSEDecisions=[]) + original_heuristic = self._jit_method_with_cleanup(index, JitMetrics=1) + if no_cse and original_heuristic: + break + + failure_count += 1 + if failure_count > 512: + raise ValueError("No valid methods found") + + observation = self.get_observation(no_cse) + self.last_info = { + 'invalid_actions' : 0, + 'method_index' : index, + 'heuristic_method' : original_heuristic, + 'no_cse_method' : no_cse, + 'current' : no_cse, + 'total_reward' : 0.0, + 'observation' : observation, + 'action_is_valid' : None + } + + return observation, self.last_info + + def step(self, action): + # the last action is always to terminate + if action == self.action_space.n - 1: + action = None + + last_info = self.last_info + if last_info is None: + raise ValueError("Must call reset() before step()") + + info = last_info.copy() + self.last_info = None + + # update action, ensure we have an up to date observation + info['action'] = action + del info['observation'] + + # Note that we have not yet updated the info dictionary for previous and current, which means + # info['current'] is the previous method at this point. We do not update info's previous/current + # until we are sure the method is JIT'ed successfully. + previous = info['current'] + + # Ensure the selected action is valid. + info['action_is_valid'] = self._is_valid_action(action, previous) + if info['action_is_valid']: + current = self._jit_method_with_cleanup(info['method_index'], JitMetrics=1, JitRLHook=1, + JitRLHookCSEDecisions=previous.cses_chosen + [action]) + + if current is not None: + observation = self.get_observation(current) + truncated = False + terminated = not current.cse_candidates or action is None + reward = self.get_rewards(previous, current) + + info['previous'] = previous + info['current'] = current + + else: + # Don't set current or observation, as we should not be using them. + observation = last_info['observation'] + truncated = True + terminated = False + reward = INVALID_ACTION_PENALTY + + else: + # action was invalid + info['invalid_actions'] += 1 + + truncated = info['invalid_actions'] >= INVALID_ACTION_LIMIT + terminated = False + observation = last_info['observation'] + reward = INVALID_ACTION_PENALTY + + info['observation'] = observation + info['total_reward'] += reward + info['terminated'] = terminated + info['truncated'] = truncated + + # These are reported only once, when the episode is done. + if terminated: + info['heuristic_score'] = info['heuristic_method'].perf_score + info['no_cse_score'] = info['no_cse_method'].perf_score + info['total_reward'] = info['total_reward'] + info['invalid_actions'] = info['invalid_actions'] + if 'current' in info: + info['final_score'] = info['current'].perf_score + + self.last_info = info + return observation, reward, terminated, truncated, info + + def get_rewards(self, prev_method : MethodContext, curr_method : MethodContext): + """Returns the reward based on the change in performance score.""" + prev = prev_method.perf_score + curr = curr_method.perf_score + + # should not happen + if np.isclose(prev, 0.0): + return 0.0 + + return REWARD_SCALE * (prev - curr) / prev + + def _is_valid_action(self, action, method): + # Terminating is only valid if we have performed a CSE. Doing no CSEs isn't allowed. + if action is None: + return bool(method.cses_chosen) + + candidate = method.cse_candidates[action] if action < len(method.cse_candidates) else None + return candidate is not None and candidate.can_apply + + @classmethod + def get_observation(cls, method : MethodContext, fill=True): + """Builds the observation from a method without normalizing the data.""" + tensors = [] + for cse in method.cse_candidates: + tensor = [] + + # one-hot encode the type + one_hot = [0.0] * 6 + one_hot[cse.type - 1] = 1.0 + tensor.extend(one_hot) + + # boolean features + tensor.extend([ + cse.can_apply, cse.live_across_call, cse.const, cse.shared_const, cse.make_cse, cse.has_call, + cse.containable + ]) + + # float features + tensor.extend([ + cse.cost_ex, cse.cost_sz, cse.use_count, cse.def_count, cse.use_wt_cnt, cse.def_wt_cnt, + cse.distinct_locals, cse.local_occurrences, cse.enreg_count + ]) + + tensors.append(tensor) + + if fill: + while len(tensors) < MAX_CSE: + tensors.append([0.0] * FEATURES) + + observation = np.vstack(tensors) + return observation + + + def _jit_method_with_cleanup(self, m_id, *args, **kwargs): + """Jits a method, but if it fails, we remove it from future consideration. Note that the + SuperPmi class will retry before returning None, so we know this method is not going to work.""" + superpmi = self.__get_or_create_superpmi() + + result = superpmi.jit_method(m_id, retry=2, *args, **kwargs) + if result is None: + self.__remove_method(m_id) + + elif np.isclose(result.perf_score, 0.0): + self.__remove_method(m_id) + result = None + + return result + + def __select_method(self): + if self.methods is None: + superpmi = self.__get_or_create_superpmi() + self.methods = [x.index for x in superpmi.enumerate_methods() if is_acceptable_for_cse(x)] + + return np.random.choice(self.methods) + + def __remove_method(self, index): + if self.methods is None: + return + + self.methods = [x for x in self.methods if x != index] + + def __get_or_create_superpmi(self): + if self.__superpmi is None: + self.__superpmi = self.pmi_context.create_superpmi() + self.__superpmi.start() + + return self.__superpmi + + def render(self) -> None: + info = self.last_info + if info is not None: + print(f"{info['method_index']} heuristic_score: {info['heuristic_method'].perf_score} " + f"no_cse_score: {info['no_cse_method'].perf_score} choices:{info['current'].cses_chosen} " + f"invalid_count:{info['invalid_actions']} ({info['current'].name})") + +__all__ = [JitCseEnv.__name__] diff --git a/src/coreclr/scripts/cse_ml/jitml/machine_learning.py b/src/coreclr/scripts/cse_ml/jitml/machine_learning.py new file mode 100644 index 00000000000000..f4af6c1c766778 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/machine_learning.py @@ -0,0 +1,217 @@ +"""The default machine learning agent which drives CSE optimization.""" + +# This file is expected to contain all of the torch/stable-baselines3 related code. If possible, +# it would be best to avoid spilling those concepts outside of this file. This is so that JitCseEnv can +# be used without requiring folks to use torch/stable-baselines3 and instead can use their own model. + +import os +import json +from typing import List, Optional + +import torch +import numpy as np + +from stable_baselines3 import A2C, DQN, PPO +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import SubprocVecEnv +import gymnasium as gym + +from .method_context import MethodContext +from .jit_cse import JitCseEnv +from .superpmi import SuperPmiContext + +class JitCseModel: + """The raw implementation of the machine learning agent.""" + def __init__(self, algorithm, device='auto', make_env=None, ent_coef=0.01, verbose=False): + if algorithm not in ('PPO', 'A2C', 'DQN'): + raise ValueError(f"Unknown algorithm {algorithm}. Must be one of: PPO, A2C, DQN") + + self.algorithm = algorithm + self.device = device + self.ent_coef = ent_coef + self.verbose = verbose + self.make_env = make_env + self._model = None + + def load(self, path): + """Loads the model from the specified path.""" + alg = self.__get_algorithm() + self._model = alg.load(path, device=self.device) + return self._model + + def save(self, path): + """Saves the model to the specified path.""" + self._model.save(path) + + @property + def num_timesteps(self): + """Returns the number of timesteps the model has been trained for.""" + return self._model.num_timesteps if self._model is not None else 0 + + def predict(self, obs, deterministic = False): + """Predicts the action to take based on the observation.""" + action, _ = self._model.predict(obs, deterministic=deterministic) + return action + + def action_probabilities(self, obs): + """Gets the probability of every action.""" + obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0).to(self._model.device) + action_distribution = self._model.policy.get_distribution(obs_tensor) + probs = action_distribution.distribution.probs + return probs.cpu().detach().numpy()[0] + + def train(self, pmi_context : SuperPmiContext, training_methods : List[MethodContext], output_dir : str, + iterations = None, parallel = None, progress_bar = True, + wrappers : Optional[List[gym.Wrapper]] = None) -> str: + """Trains a model from scratch. + + Args: + pmi_context: The SuperPmiContext to use for training. + training_methods : The methods to train on. + output_dir: The directory to save the model to. + iterations: The number of iterations to train for. Defaults to 100,000. + parallel: The number of parallel environments to use. Defaults to single-process (None). + progress_bar: Whether to display a progress bar. Defaults to True. + + Returns: + The full path to the trained model. + """ + training_methods = [m.index for m in training_methods] + os.makedirs(output_dir, exist_ok=True) + + def default_make_env(): + env = JitCseEnv(pmi_context, training_methods) + if wrappers: + for wrapper in wrappers: + env = wrapper(env) + return env + + make_env = self.make_env or default_make_env + if parallel is not None and parallel > 1: + env = make_vec_env(make_env, n_envs=parallel, vec_env_cls=SubprocVecEnv) + else: + env = make_env() + + try: + self._model = self._create(env, tensorboard_log=os.path.join(output_dir, 'logs')) + + iterations = 100_000 if iterations is None else iterations + callback = LogCallback(self._model, output_dir) if self.algorithm in ('PPO', 'A2C') else None + self._model.learn(iterations, progress_bar=progress_bar, callback=callback) + + save_path = os.path.join(output_dir, self.algorithm.lower() +'.zip') + self.save(save_path) + return save_path + + finally: + env.close() + + def _create(self, env, **kwargs): + alg = self.__get_algorithm() + if alg == PPO: + return alg('MlpPolicy', env, device=self.device, ent_coef=self.ent_coef, verbose=self.verbose, **kwargs) + + return alg('MlpPolicy', env, device=self.device, verbose=self.verbose, **kwargs) + + def __get_algorithm(self): + match self.algorithm: + case 'PPO': + return PPO + case 'A2C': + return A2C + case 'DQN': + return DQN + case _: + raise ValueError(f"Unknown algorithm {self.algorithm}. Must be one of: PPO, A2C, DQN") + +class LogCallback(BaseCallback): + """A callback to log reward values to tensorboard and save the best models.""" + # pylint: disable=too-many-instance-attributes + + def __init__(self, model : PPO | A2C, save_dir : str, last_model_freq = 500_000): + super().__init__() + + self.model = model + self.next_save = model.n_steps + self.last_model_freq = last_model_freq + self.last_model_next_save = self.last_model_freq + + self.best_reward = -np.inf + + self.save_dir = save_dir + + self._rewards = [] + self._invalid_choices = [] + self._result_vs_heuristic = [] + self._result_vs_no_cse = [] + self._better_or_worse = [] + self._choice_count = [] + + + def _on_step(self) -> bool: + self._update_stats() + + if self.n_calls > self.next_save: + self.next_save += self.model.n_steps + + rew_mean = np.mean(self._rewards) if self._rewards else -np.inf + if rew_mean > self.best_reward: + self.best_reward = rew_mean + self._save_incremental(rew_mean, os.path.join(self.save_dir, 'best_reward.zip')) + + if self.model.num_timesteps >= self.last_model_next_save: + self.last_model_next_save += self.last_model_freq + self._save_incremental(rew_mean, os.path.join(self.save_dir, f'ppo_{self.model.num_timesteps}.zip')) + + if self._invalid_choices: + self.logger.record('results/invalid_choices', np.mean(self._invalid_choices)) + + if self._result_vs_heuristic: + self.logger.record('results/vs_heuristic', np.mean(self._result_vs_heuristic)) + + if self._result_vs_no_cse: + self.logger.record('results/vs_no_cse', np.mean(self._result_vs_no_cse)) + + if self._better_or_worse: + self.logger.record('results/better_than_heuristic', np.mean(self._better_or_worse)) + + if self._choice_count: + self.logger.record('results/num_cse', np.mean(self._choice_count)) + + self._rewards.clear() + self._invalid_choices.clear() + self._result_vs_heuristic.clear() + self._result_vs_no_cse.clear() + self._better_or_worse.clear() + self._choice_count.clear() + + return True + + def _update_stats(self): + for info in self.locals['infos']: + if 'final_score' not in info: + continue + + final = info['final_score'] + heuristic = info['heuristic_score'] + no_cse = info['no_cse_score'] + + if heuristic != 0: + self._result_vs_heuristic.append((heuristic - final) / heuristic) + + if no_cse != 0: + self._result_vs_no_cse.append((no_cse - final) / no_cse) + + self._better_or_worse.append(1 if final < heuristic else -1 if final < heuristic else 0) + self._choice_count.append(len(info['current'].cses_chosen)) + self._rewards.append(info['total_reward']) + + def _save_incremental(self, reward, save_path): + self.model.save(save_path) + + metadata = { "iterations" : self.num_timesteps, 'reward' : reward} + with open(save_path + '.json', 'w', encoding='utf-8') as f: + json.dump(metadata, f, indent=4) + +__all__ = [JitCseModel.__name__] diff --git a/src/coreclr/scripts/cse_ml/jitml/method_context.py b/src/coreclr/scripts/cse_ml/jitml/method_context.py new file mode 100644 index 00000000000000..075f47028dc8b1 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/method_context.py @@ -0,0 +1,92 @@ +"""A wrapper around the CSE_Candidate and MethodContext classes. CseCandidate mirrors code in +src/coreclr/jit/optcse.cpp.""" + +from enum import Enum +from typing import List, Optional +from pydantic import BaseModel, ValidationError, field_validator + +class JitType(Enum): + """The type of a CSE candidate. Mirrors CSE_HeuristicRLHook's enum.""" + OTHER : int = 0 + INT : int = 1 + LONG : int = 2 + FLOAT : int = 3 + DOUBLE : int = 4 + STRUCT : int = 5 + SIMD : int = 6 + +class CseCandidate(BaseModel): + """A CSE candidate. Mirrors CSE_Candidate features in CSE_HeuristicRLHook.cpp.""" + index : int + applied : Optional[bool] = False + viable : bool + live_across_call : bool + const : bool + shared_const : bool + make_cse : bool + has_call : bool + containable : bool + type : int + cost_ex : int + cost_sz : int + use_count : int + def_count : int + use_wt_cnt : int + def_wt_cnt : int + distinct_locals : int + local_occurrences : int + bb_count : int + block_spread : int + enreg_count : int + for_testing : Optional[bool] = False + + @field_validator('applied', 'viable', 'live_across_call', 'const', 'shared_const', 'make_cse', 'has_call', + 'containable', mode='before') + @classmethod + def validate_bool(cls, v): + """Validates that the value is a boolean or is a 0 or 1.""" + if isinstance(v, int) and v in [0, 1]: + return bool(v) + + if isinstance(v, bool): + return v + + raise ValidationError(f"Value must be either 1, 0, or a boolean, got {v}") + + @property + def can_apply(self): + """Returns True if the candidate is viable and not applied.""" + return self.viable and not self.applied + +class MethodContext(BaseModel): + """A superpmi method context.""" + index : int + name : str + hash : str + total_bytes : int + prolog_size : int + instruction_count : int + perf_score : float + bytes_allocated : int + num_cse : int + num_cse_candidate : int + heuristic : str + cses_chosen : List[int] + cse_candidates : List[CseCandidate] + + def __str__(self): + return f"{self.index}: {self.name}" + + # validate that perf_score is never negative: + @field_validator('perf_score', mode='before') + @classmethod + def _validate_perf_score(cls, v): + if v < 0: + raise ValueError("perf_score must not be negative") + return v + +__all__ = [ + CseCandidate.__name__, + MethodContext.__name__, + JitType.__name__ +] diff --git a/src/coreclr/scripts/cse_ml/jitml/superpmi.py b/src/coreclr/scripts/cse_ml/jitml/superpmi.py new file mode 100644 index 00000000000000..db08f6f408c61c --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/superpmi.py @@ -0,0 +1,243 @@ +"""Functions for interacting with SuperPmi.""" + +import json +import os +import subprocess +import re +from typing import Iterable, List, Optional +from pydantic import BaseModel, field_validator + +from .method_context import MethodContext + +class SuperPmiContext(BaseModel): + """Information about how to construct a SuperPmi object. This tells us where to find CLR's CORE_ROOT with + the superpmi and jit, and which .mch file to use. Additionally, it tells us which methods to use for training + and testing.""" + core_root : str + mch : str + jit : Optional[str] = None + methods : Optional[List[MethodContext]] = [] + + @field_validator('core_root', 'mch', mode='before') + @classmethod + def _validate_path(cls, v): + if not os.path.exists(v): + raise FileNotFoundError(f"{v} does not exist.") + + return v + + @field_validator('jit', mode='before') + @classmethod + def _validate_optional_path(cls, v): + if v is not None and not os.path.exists(v): + raise FileNotFoundError(f"{v} does not exist.") + + return v + + @staticmethod + def create_from_mch(mch : str, core_root : str, jit : Optional[str] = None) -> 'SuperPmiContext': + """Loads the SuperPmiContext from the specified arguments.""" + result = SuperPmiContext(core_root=core_root, mch=mch, jit=jit) + + methods = [] + with SuperPmi(result) as superpmi: + for method in superpmi.enumerate_methods(): + methods.append(method) + + result.methods = methods + return result + + def save(self, file_path:str): + """Saves the SuperPmiContext to a file.""" + with open(file_path, 'w', encoding="utf8") as f: + json.dump(self.model_dump(), f) + + @staticmethod + def load(file_path:str): + """Loads the SuperPmiContext from a file.""" + if not os.path.exists(file_path): + raise FileNotFoundError(f"{file_path} does not exist.") + + with open(file_path, 'r', encoding="utf8") as f: + data = json.load(f) + return SuperPmiContext(**data) + + def create_superpmi(self, verbosity:str = 'q'): + """Creates a SuperPmi object from this context.""" + return SuperPmi(self, verbosity) + +class SuperPmi: + """Controls one instance of superpmi.""" + def __init__(self, context : SuperPmiContext, verbosity:str = 'q'): + """Constructor. + core_root is the path to the coreclr build, usually at [repo]/artifiacts/bin/coreclr/[arch]/. + verbosity is the verbosity level of the superpmi process. Default is 'q'.""" + self._process = None + self._feature_names = None + self.context = context + self.verbose = verbosity + + if os.name == 'nt': + self.superpmi_path = os.path.join(context.core_root, 'superpmi.exe') + self.jit_path = os.path.join(context.core_root, context.jit if context.jit else 'clrjit.dll') + else: + self.superpmi_path = os.path.join(context.core_root, 'superpmi') + self.jit_path = os.path.join(context.core_root, context.jit if context.jit else 'libclrjit.so') + + if not os.path.exists(self.superpmi_path): + raise FileNotFoundError(f"superpmi {self.superpmi_path} does not exist.") + + if not os.path.exists(self.jit_path): + raise FileNotFoundError(f"jit {self.jit_path} does not exist.") + + def __del__(self): + self.stop() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *_): + self.stop() + + def jit_method(self, method_or_id : int | MethodContext, retry=1, **options) -> MethodContext: + """Attempts to jit the method, and retries if it fails up to "retry" times.""" + if retry < 1: + raise ValueError("retry must be greater than 0.") + + for _ in range(retry): + result = self.__jit_method(method_or_id, **options) + if result is not None: + return result + + self.stop() + self.start() + + return None + + def __jit_method(self, method_or_id : int | MethodContext, **options) -> MethodContext: + """Jits the method given by id or MethodContext.""" + process = self._process + if process is None: + raise ValueError("SuperPmi process is not running. Use a 'with' statement.") + + if isinstance(method_or_id, MethodContext): + method_or_id = method_or_id.index + + if "JitMetrics" not in options: + options["JitMetrics"] = 1 + + if self._feature_names is None and "JitRLHook" in options: + options['JitRLHookEmitFeatureNames'] = 1 + + torun = f"{method_or_id}!" + torun += "!".join([f"{key}={value}" for key, value in options.items()]) + + if not process.poll(): + self.stop() + process = self.start() + + process.stdin.write(f"{torun}\n".encode('utf-8')) + process.stdin.flush() + + result = None + output = "" + + while not output.startswith('[streaming] Done.'): + output = process.stdout.readline().decode('utf-8').strip() + if output.startswith(';'): + result = self._parse_method_context(output) + + return result + + def enumerate_methods(self) -> Iterable[MethodContext]: + """List all methods in the mch file.""" + params = [self.superpmi_path, self.jit_path, self.context.mch, '-v', 'q', '-jitoption', 'JitMetrics=1', + '-jitoption', 'JitRLHook=1', '-jitoption', 'JitRLHookEmitFeatureNames=1'] + + try: + # pylint: disable=consider-using-with + process = subprocess.Popen(params, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + for line in process.stdout: + line = line.decode('utf-8').strip() + if line.startswith(';'): + yield self._parse_method_context(line) + + finally: + if process.poll(): + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + def _parse_method_context(self, line:str) -> MethodContext: + if self._feature_names is None: + # find featureNames in line + feature_names_header = 'featureNames ' + start = line.find(feature_names_header) + stop = line.find(' ', start + len(feature_names_header)) + if start > 0: + self._feature_names = line[start + len(feature_names_header):stop].split(',') + self._feature_names.insert(0, 'id') + + properties = {} + properties['index'] = int(re.search(r'spmi index (\d+)', line).group(1)) + properties['name'] = re.search(r'for method ([^ ]+):', line).group(1) + properties['hash'] = re.search(r'MethodHash=([0-9a-f]+)', line).group(1) + properties['total_bytes'] = int(re.search(r'Total bytes of code (\d+)', line).group(1)) + properties['prolog_size'] = int(re.search(r'prolog size (\d+)', line).group(1)) + properties['instruction_count'] = int(re.search(r'instruction count (\d+)', line).group(1)) + properties['perf_score'] = float(re.search(r'PerfScore ([0-9.]+)', line).group(1)) + properties['bytes_allocated'] = int(re.search(r'allocated bytes for code (\d+)', line).group(1)) + properties['num_cse'] = int(re.search(r'num cse (\d+)', line).group(1)) + properties['num_cse_candidate'] = int(re.search(r'num cand (\d+)', line).group(1)) + properties['heuristic'] = re.search(r'num cand \d+ (.+) ', line).group(1) + + seq = re.search(r'seq ([0-9,]+) spmi', line) + if seq is not None: + properties['cses_chosen'] = [int(x) for x in seq.group(1).split(',')] + else: + properties['cses_chosen'] = [] + + cse_candidates = None + if self._feature_names is not None: + # features CSE #032,3,10,3,3,150,150,1,1,0,0,0,0,0,0,37 + candidates = re.findall(r'features #([0-9,]+)', line) + if candidates is not None: + cse_candidates = [{self._feature_names[i]: int(x) for i, x in enumerate(candidate.split(','))} + for candidate in candidates] + + for i, candidate in enumerate(cse_candidates): + candidate['index'] = i + if i in properties['cses_chosen']: + candidate['applied'] = True + + properties['cse_candidates'] = cse_candidates if cse_candidates is not None else [] + + return MethodContext(**properties) + + def start(self): + """Starts and returns the superpmi process.""" + if self._process is None: + params = [self.superpmi_path, self.jit_path, '-streaming', 'stdin', self.context.mch] + if self.verbose is not None: + params.extend(['-v', self.verbose]) + + # pylint: disable=consider-using-with + self._process = subprocess.Popen(params, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + + return self._process + + def stop(self): + """Closes the superpmi process.""" + if self._process is not None: + self._process.stdin.write(b"quit\n") + self._process.terminate() + self._process = None + +__all__ = [ + SuperPmi.__name__, + SuperPmiContext.__name__, +] diff --git a/src/coreclr/scripts/cse_ml/jitml/wrappers.py b/src/coreclr/scripts/cse_ml/jitml/wrappers.py new file mode 100644 index 00000000000000..56dbf12cbc2af8 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/jitml/wrappers.py @@ -0,0 +1,117 @@ +"""A reward wrapper for the CSE environment that provides rewards based not just on the change in +performance score, but also on the quality of the CSE choices made.""" + +from typing import List, Optional, SupportsFloat +import gymnasium as gym +import numpy as np + +from .method_context import MethodContext +from .jit_cse import JitCseEnv +from .superpmi import SuperPmi + +OPTIMAL_BONUS = 0.05 +SUBOPTIMAL_PENALTY = -0.01 +NEUTRAL_PENALTY = -0.005 + +class OptimalCseWrapper(gym.Wrapper): + """A wrapper for the CSE environment that provides rewards based not just on the change in + performance score, but also on the quality of the CSE choices made.""" + def __init__(self, env : JitCseEnv): + super().__init__(env) + self.superpmi : SuperPmi = env.unwrapped.pmi_context.create_superpmi() + self.superpmi.start() + + def step(self, action): + """Steps the environment.""" + observation, reward, terminated, truncated, info = self.env.step(action) + reward = self._get_reward(reward, info) + return observation, reward, terminated, truncated, info + + def _get_reward(self, reward : SupportsFloat, info) -> SupportsFloat: + # We'll let the parent class handle the reward in these cases. + if info['truncated'] or not info['action_is_valid']: + return reward + + m_idx = info['method_index'] + current = info['current'] + previous = info['previous'] + previous_score = previous.perf_score + + # Did we choose to end optimization? + if info['action'] is None: + all_cses = self._get_all_cses(m_idx, previous, None) + best_perf_score = min(all_cses, key=lambda x: x.perf_score).perf_score if all_cses else np.inf + + if not np.isclose(best_perf_score, previous_score) and best_perf_score < previous_score: + reward += SUBOPTIMAL_PENALTY + + # Otherwise we chose a CSE + else: + # We apply a tiny penalty for choosing a CSE that matches the previous score. Choosing a CSE that + # doesn't change the score still has a cost, but we don't want this penalty to be so high that the + # agent avoids making choices. + if np.isclose(current.perf_score, previous_score): + reward += NEUTRAL_PENALTY + + # If we improved the performance score, give a bonus for choosing the best option out of all of them. + elif current.perf_score < previous_score: + # We improved the performance score, but was it the best choice? + all_cses = self._get_all_cses(m_idx, previous, current.cses_chosen[-1]) + best_perf_score = min(all_cses, key=lambda x: x.perf_score).perf_score if all_cses else np.inf + if np.isclose(best_perf_score, current.perf_score) or current.perf_score < best_perf_score: + reward += OPTIMAL_BONUS + + return reward + + def _get_all_cses(self, m_idx, previous : MethodContext, selected : Optional[int]) -> List[MethodContext]: + # If we aren't given a current method, then no CSEs were applied. + assert selected not in previous.cses_chosen + + all_cses = [self.superpmi.jit_method(m_idx, JitMetrics=1, JitRLHook=1, + JitRLHookCSEDecisions=previous.cses_chosen + [x.index]) + for x in previous.cse_candidates + if x.index != selected and x.can_apply] + + all_cses = [x for x in all_cses if x is not None] + return all_cses + + +class NormalizeFeaturesWrapper(gym.ObservationWrapper): + """Removes unused features from the observation space.""" + + # Calculated using scripts/calculate_feature_norm.py + obs_subtract = np.array([0.0, 0.0, 0.0, 0.0, 0, 0.0, 0.0, 0.0, 0.0, 0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0], dtype=np.float32) + obs_scale = np.array([1.0, 1.0, 1.0, 1.0, 1, 1.0, 1.0, 1.0, 1.0, 1, 1.0, 1.0, 1.0, 0.012658227848101266, + 0.018867924528301886, 0.001763668430335097, 0.005291005291005291, 0.06988495589628721, + 0.07344255107610509, 0.2, 0.16666666666666666, 0.0013623978201634877], dtype=np.float32) + obs_log1p = np.array([False, False, False, False, False, False, False, False, False, False, False, False, False, + False, False, False, False, True, True, False, False, False], dtype=bool) + + + unused_features = ['shared_const', 'type_struct'] + + def __init__(self, env): + super().__init__(env) + + # Remove the unused features from the observation space. + self.filter = np.array([name in NormalizeFeaturesWrapper.unused_features for name in env.observation_columns], + dtype=bool) + + self.observation_space = gym.spaces.Box( + low=env.observation_space.low[:, self.filter], + high=env.observation_space.high[:, self.filter], + dtype=env.observation_space.dtype + ) + + def observation(self, observation): + """Builds the observation from a method.""" + observation[:, self.obs_log1p] = np.log1p(observation[:, self.obs_log1p]) + observation = (observation - self.obs_subtract) * self.obs_scale + + # We still need to clip the data since there could be some values we didn't encounter when building + # the scaling factors + np.clip(observation, 0.0, 1.0, out=observation) + return observation[:, self.filter] + +__all__ = [NormalizeFeaturesWrapper.__name__, OptimalCseWrapper.__name__] diff --git a/src/coreclr/scripts/cse_ml/readme.md b/src/coreclr/scripts/cse_ml/readme.md new file mode 100644 index 00000000000000..21259fad336ad1 --- /dev/null +++ b/src/coreclr/scripts/cse_ml/readme.md @@ -0,0 +1,126 @@ +# Introduction + +This project is a Reinforcement Learning gymnasium used to train a machine learning model to choose when to apply the Common Subexpression Elimination optimization in the JIT. + +Currently, it *almost* matches the default, hand-written CSE heuristic in the JIT, and does so with a simple reward function and features which are not normalized or smartly chosen. This is intended to be a playground to try to find optimal features, reward function, neural network, and architecture for a fine-tuned CSE model. + +This project works best/easiest on Ubuntu 22 (WSL2 is fine), but it also works on Windows. + +# Setup + +Follow the standard [Workflow Guide](../../../../docs/workflow/README.md) to get a working environment to build .Net. + +## Runtime setup + +Build a checked version of the runtime and tests: + +```bash +./build.sh -subset clr -c Checked +./build.sh -subset libs -c Release -rc Checked +./src/tests/build.sh x64 checked skipmanaged skipnative +``` + +Download superpmi data. This will download a few gigs worth of data: + +```bash +python src/coreclr/scripts/superpmi.py download +``` + +## Python Setup + +This was developed and tested with Python 3.10 (3.11 on Windows). Python 3.10 is the default Python version in Ubuntu 22. + +First, install all dependencies from requirements.txt in this directory: + +```bash +pip install -r requirements.txt +``` + +## Training a Model + +This project uses SuperPMI to JIT methods without needing to load the runtime. SuperPMI records data about methods in Method Contexts, stored in a .mch file under the `artifacts/spmi` folder. To train a model, you need to specify the CORE_ROOT environment variable to point to the checked runtime we built above (or use the --core_root parameter) and specify a .mch file to use. Here is an example: + +```bash +python train.py ./model_output_path/ ~/git/dotnet/runtime/artifacts/[build]/[file].mch \ + --core_root ~/git/dotnet/runtime/artifacts/bin/coreclr/linux.x64.Checked/ \ + --iterations 5000000 --parallel 10 +``` + +This will train a model and store it in `./model_output_path`. Here are the command line options: + +``` bash +usage: train.py [-h] [--core_root CORE_ROOT] [--parallel PARALLEL] [--iterations ITERATIONS] [--algorithm ALGORITHM] + [--test-percent TEST_PERCENT] [--reward-optimal-cse] [--normalize-features] model_path mch +``` + +**algorithm** - PPO, A2C, or DQN. PPO is the default and currently the only one that seems to converge to a solution. Still working on getting A2C and DQN to work. + +**iterations** - The number of iterations (individual CSE choices) to train on. PPO builds a decent model at around 1 million iterations (the default). It starts getting close to to the default CSE Heuristic at around 3-5 million iterations. + +**parallel** - Use multiprocessing to train in parallel. This specifies the number of processes (default is 1). + +**test-percent** - What percentage of the .mch file to reserve for testing the model (default is .1, 10%). + +**reward-optimal-cse** - Attempt to find the "optimal" choice at each iteration step and reward the model for picking the best option. This does have a positive impact, but slows down training by 4x. + +**normalize-features** - Performs normalization on features, currently causes the model to not train. (So don't use it, still investigating why.) + +## Evaluating a Model + +Use `evaluate.py` to evaluate a model's performance. Simply pass in the model path, `.mch` file, and CORE_ROOT used with train.py and it will list how many methods were improved or regressed by using the model versus the default heuristic. Note that snapshots of the model are taken at regular intervals and this file will attempt to evaluate all of them. + +# The Code + +[jit_cse.py](jitml/jit_cse.py) - This contains the environment itself. This is meant to produce the most basic observation space and rewards. If you want to customize rewards, the features that the model uses, etc, you can create a gym wrapper that wraps the environment. + +[wrappers.py](jitml/wrappers.py) - This is an example of modifying the gym environment. `NormalizeFeaturesWrapper` is an example of a `gym.ObservationWrapper`. It attempts to normalize all inputs to the model in the range of `[0, 1]`. `OptimalCseWrapper` is a full `gym.Wrapper` that wraps the `step` function. It enhances the default reward function to attempt to reward/punish the model for making the exact correct or incorrect decisions. + +[machine_learning.py](jitml/machine_learning.py) - This file contains all of the machine learning implementation for this project. We currently just use stable-baselines to implement PPO/A2C/DQN. Additionally, we don't currently define our own neural network. The neural network architecture is pre-defined by the `MlpPolicy` parameter when creating the reinforcement learning agent. A custom neural network can be specified by building a [Custom Network Architecture](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-network-architecture). + +[optcse.cpp](../../jit/optcse.cpp) - This contains the implementation of `CSE_HeuristicRLHook` used to give the agent the ability to control CSE optimization choices. Specifically `CSE_HeuristicRLHook::GetFeatures` and `CSE_HeuristicRLHook::s_featureNameAndType` are the raw feature building blocks used by `JitCseEnv` to build the observation that the model is trained on. + +[method_context.py](jitml/method_context.py) - This contains the Python classes that mirror the features produced by `CSE_HeuristicRLHook`. This needs to be kept in sync with optcse.cpp. + +## Making Changes + +Typically, most changes should be implemented as gym wrappers. The `JitCseEnv` should be as basic and straight forward as possible. `JitCseModel.train` provides a `wrappers` parameter that you can use to pass in your custom wrappers to test your changes. + +## Testing Changes + +Use Tensorboard to see live updates on how training is going: + +``` bash +tensorboard --host 0.0.0.0 --logdir=./model_output_path/ +``` + +Both A2C and PPO provide extra metrics on the Tensorboard to see if the model is properly training (DQN does not yet have this). Here is an example of successful training (blue) vs a model that did not train (red): + +![Tensorboard](img/training.png) + +Typically, the `rollout/ep_rew_mean`, `results/vs_heuristic`, and `results/vs_no_cse` metrics should all trend upwards over time from a lower value if the model is learning. + +Once you see that a model is training successfully, use `evaluate.py` to see how much better or worse it is over the baseline. + +**NOTE:** The `results/` metrics are a rolling average of comparisons versus baseline since the last metric datapoint was emitted. This metric crossing 0 into the positive does not necessarily mean that the model is performing better than the baseline heuristic in a general sense. Only that it did better on the small subset of training functions it just recently attempted to optimize. Whether or not the model actually performs better than baseline is left to `evaluate.py` after it is finished training. + +## SuperPMI + +SuperPMI is used to do the work of JIT'ing functions. You do not need to use it directly. However, if you need to test the JIT'ing of a method: + +```bash +superpmi libclrjit.so -v q -streaming stdin {mch} +``` + +Then use the format `[method_id]!JitMetrics=1!Var1=Value1!Var2=Value2` to jit methods. For example: + +``` +123!JitMetrics=1 <= JIT the method in the normal way +123!JitMetrics=1!JitRLHook=1 <= Use the reinforcement learning hook + +123!JitMetrics=1!JitRLHook=1!JitRLHookCSEDecisions=2,3,1 <= Enable CSE 2, 3, then 1 when JIT'ing +``` + +## Pylint + +Please run `pylint *py jitml/` before checkin and clean up any warnings (no need to run it on the tests). It's ok to silence warnings with `#pylint disable` if it makes more sense to do that than clean up what it's complaining about. + diff --git a/src/coreclr/scripts/cse_ml/requirements.txt b/src/coreclr/scripts/cse_ml/requirements.txt new file mode 100644 index 00000000000000..8b0c317a47992d --- /dev/null +++ b/src/coreclr/scripts/cse_ml/requirements.txt @@ -0,0 +1,7 @@ +gymnasium +numpy +pydantic +pandas +stable_baselines3 +stable-baselines3[extra] +torch diff --git a/src/coreclr/scripts/cse_ml/scripts/add_jitml_path.py b/src/coreclr/scripts/cse_ml/scripts/add_jitml_path.py new file mode 100644 index 00000000000000..fcc145e0217e1e --- /dev/null +++ b/src/coreclr/scripts/cse_ml/scripts/add_jitml_path.py @@ -0,0 +1,6 @@ +"""Adds jitml to the path for imports. Import add_jitml_path before any jitml.""" + +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) diff --git a/src/coreclr/scripts/cse_ml/scripts/calculate_feature_norm.py b/src/coreclr/scripts/cse_ml/scripts/calculate_feature_norm.py new file mode 100755 index 00000000000000..979a587b6e205f --- /dev/null +++ b/src/coreclr/scripts/cse_ml/scripts/calculate_feature_norm.py @@ -0,0 +1,128 @@ +#!/usr/bin/python +"""Uses a single .mch file to calculate feature normalization to scale all features to [0, 1].""" + +import argparse +import os + +import numpy as np +import tqdm +from pandas import DataFrame, read_csv + +import add_jitml_path as _ # pylint: disable=import-error +from jitml import SuperPmiContext, JitCseEnv +from train import validate_core_root + +def get_feature_data(ctx : SuperPmiContext, save=True) -> DataFrame: + """Returns a DataFrame of all features in a .mch file.""" + + csv_file = ctx.mch + ".features.csv" + if save and os.path.exists(csv_file): + return read_csv(csv_file) + + feature_data = {} + column_names = JitCseEnv.observation_columns + for col in column_names: + feature_data[col] = [] + + ctx.resplit_data(0) + with ctx.create_superpmi() as superpmi: + for m_id in tqdm.tqdm(ctx.training_methods): + method = superpmi.jit_method(m_id, JitMetrics=1, JitRLHook=1) + if method is None: + continue + + observation = JitCseEnv.get_observation(method, fill=False) + for features in observation: + for c, name in enumerate(column_names): + value = features[c] + feature_data[name].append(value) + + if save: + df = DataFrame(feature_data) + df.to_csv(csv_file, index=False) + + return df + +def get_scaling(df : DataFrame): + """Calculate scaling.""" + df = df.copy() + + # Heuristic for using log1p: + # If the standard deviation is greater than 1000 and the max is greater than 10000, use log1p. + std_over_1000 = (df.std() > 100) & (df.max() > 10000) + use_log1p = std_over_1000.values + + # apply log1p to columns that need it + log1p_cols = df.columns[use_log1p] + for col in log1p_cols: + df[col] = np.log1p(df[col]) + + # calculate the scaling after log1p + diff = df.max() - df.min() + subtract = [0] * len(diff) + scale = [1] * len(diff) + + for i, col in enumerate(df.columns): + if df[col].max() == df[col].min(): + continue + + diff_col = diff[col] + subtract[i] = df[col].min() + scale[i] = 1 / diff_col + + return subtract, scale, list(use_log1p) + +def _print_stats(df): + stats = { + "min": df.min(), + "max": df.max(), + "avg": df.mean(), + "std": df.std(), + } + df_stats = DataFrame(stats) + print(df_stats) + +def parse_args(): + """Parses the command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("mch", help="The mch file to calculate feature normalization on.") + parser.add_argument("--core_root", default=None, help="The core_root directory.") + + args = parser.parse_args() + args.core_root = validate_core_root(args.core_root) + return parser.parse_args() + +def main(args): + """Entry point for the script.""" + file_path = args.mch + ".json" + if os.path.exists(file_path): + ctx = SuperPmiContext.load(file_path) + + else: + ctx = SuperPmiContext(core_root=args.core_root, mch=args.mch) + ctx.find_methods_and_split(0.1) + ctx.save(file_path) + + features = get_feature_data(ctx) + print("Unnormalized feature statistics:") + _print_stats(features) + print() + + # Normalize + subtract, scale, use_log1p = get_scaling(features) + for i, col in enumerate(features.columns): + if use_log1p[i]: + features[col] = np.log1p(features[col]) + + features = (features - subtract) * scale + + print("Normalized feature statistics:") + _print_stats(features) + + print() + print(f"subract = {repr(subtract)}") + print(f"scales = {repr(scale)}") + print(f"use_log1p = {repr(use_log1p)}") + +if __name__ == "__main__": + main(parse_args()) diff --git a/src/coreclr/scripts/cse_ml/train.py b/src/coreclr/scripts/cse_ml/train.py new file mode 100755 index 00000000000000..3f1f8d1aacd61d --- /dev/null +++ b/src/coreclr/scripts/cse_ml/train.py @@ -0,0 +1,70 @@ +#!/usr/bin/python + +"""Trains JIT reinforcment learning on superpmi data.""" +import os +import argparse + +from jitml import SuperPmiContext, JitCseModel, OptimalCseWrapper, NormalizeFeaturesWrapper, split_for_cse + +def validate_core_root(core_root): + """Validates and returns the core_root directory.""" + core_root = core_root or os.environ.get("CORE_ROOT", None) + if core_root is None: + raise ValueError("--core_root must be specified or set as the environment variable CORE_ROOT.") + + return core_root + +def parse_args(): + """usage: train.py [-h] [--core_root CORE_ROOT] [--parallel n] [--iterations i] model_path mch""" + parser = argparse.ArgumentParser() + parser.add_argument("model_path", help="The directory to save the model to.") + parser.add_argument("mch", help="The mch file of functions to train on.") + parser.add_argument("--core_root", default=None, help="The coreclr root directory.") + parser.add_argument("--parallel", type=int, default=None, help="The number of parallel environments to use.") + parser.add_argument("--iterations", type=int, default=None, help="The number of iterations to train for.") + parser.add_argument("--algorithm", default="PPO", help="The algorithm to use. (default: PPO)") + parser.add_argument("--test-percent", type=float, default=0.1, + help="The percentage of data to use for testing. (default: 0.1)") + parser.add_argument("--reward-optimal-cse", action='store_true', help="Use smarter rewards. (default: False)") + parser.add_argument("--normalize-features", action='store_true', help="Normalize features. (default: False)") + + args = parser.parse_args() + args.core_root = validate_core_root(args.core_root) + return args + +def main(args): + """Main entry point.""" + output_dir = args.model_path + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Load or create the superpmi context. + spmi_file = args.mch + ".json" + if os.path.exists(spmi_file): + ctx = SuperPmiContext.load(spmi_file) + else: + print(f"Creating SuperPmiContext '{spmi_file}', this may take several minutes...") + ctx = SuperPmiContext.create_from_mch(args.mch, args.core_root) + ctx.save(spmi_file) + + test_methods, training_methods = split_for_cse(ctx.methods, 0.1) + print(f"Training with {len(training_methods)} methods, holding back {len(test_methods)} for testing.") + + # Define our own environment (with wrappers) if requested. + + # Train the model. + rl = JitCseModel(args.algorithm) + + wrappers = [] + if args.reward_optimal_cse: + wrappers.append(OptimalCseWrapper) + + if args.normalize_features: + wrappers.append(NormalizeFeaturesWrapper) + + iterations = args.iterations if args.iterations is not None else 1_000_000 + path = rl.train(ctx, training_methods, output_dir, iterations=iterations, parallel=args.parallel, wrappers=wrappers) + print(f"Model saved to: {path}") + +if __name__ == "__main__": + main(parse_args())