Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/Transform/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,35 @@ def InferEffectsPass : Pass<"transform-infer-effects"> {
}];
}

def PreloadLibraryPass : Pass<"transform-preload-library"> {
let summary = "preload transform dialect library";
let description = [{
This pass preloads a transform library and makes it available to a subsequent
transform interpreter passes. The preloading occurs into the Transform
dialect and thus provides very limited functionality that does not scale.

Warning: Only a single such pass should exist for a given MLIR context.
This is a temporary solution until a resource-based solution is available.
TODO: use a resource blob.
}];
let options = [
ListOption<"transformLibraryPaths", "transform-library-paths", "std::string",
"Optional paths to files with modules that should be merged into the "
"transform module to provide the definitions of external named sequences.">
];
}

def InterpreterPass : Pass<"transform-interpreter"> {
let summary = "transform dialect interpreter";
let description = [{
This pass runs the transform dialect interpreter and applies the named
sequence transformation specified by the provided name (defaults to
`__transform_main`).
}];
let options = [
Option<"entryPoint", "entry-point", "std::string",
/*default=*/[{"__transform_main"}],
"Entry point of the pass pipeline.">,
];
}
#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_PASSES
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,31 @@ class Region;

namespace transform {
namespace detail {

/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
LogicalResult expandPathsToMLIRFiles(ArrayRef<std::string> &paths,
MLIRContext *context,
SmallVectorImpl<std::string> &fileNames);

/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to parse, verify, aggregate and link the content of all mlir files
/// nested under `transformLibraryPaths` and containing transform dialect
/// specifications.
LogicalResult
assembleTransformLibraryFromPaths(MLIRContext *context,
ArrayRef<std::string> transformLibraryPaths,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
Expand Down
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ ArrayRef<Operation *>
transform::TransformState::getPayloadOpsView(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
assert(
iter != operationMapping.end() &&
"cannot find mapping for payload handle (param/value handle provided?)");

if (iter == operationMapping.end()) {
value.dump();
assert(false &&
"cannot find mapping for payload handle (param/value handle "
"provided?)");
}
return iter->getSecond();
}

Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1761,8 +1761,20 @@ DiagnosedSilenceableFailure
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// Nothing to do here.
return DiagnosedSilenceableFailure::success();
if (isExternal())
return emitDefiniteFailure() << "unresolved external named sequence";

// Map the entry block argument to the list of operations.
// Note: this is the same implementation as PossibleTopLevelTransformOp but
// without attaching the interface / trait since that is tailored to a
// dangling top-level op that does not get "called".
auto scope = state.make_region_scope(getBody());
if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
state, this->getOperation(), getBody())))
return DiagnosedSilenceableFailure::definiteFailure();

return applySequenceBlock(getBody().front(),
FailurePropagationMode::Propagate, state, results);
}

void transform::NamedSequenceOp::getEffects(
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
InterpreterPass.cpp
PreloadLibraryPass.cpp
TransformInterpreterPassBase.cpp
TransformInterpreterUtils.cpp

Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

using namespace mlir;

namespace mlir {
namespace transform {
#define GEN_PASS_DEF_INTERPRETERPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir

namespace {
class InterpreterPass
: public transform::impl::InterpreterPassBase<InterpreterPass> {
public:
using Base::Base;

LogicalResult initialize(MLIRContext *context) override {
// TODO: use a resource blob.
ModuleOp transformModule =
transform::detail::getPreloadedTransformModule(context);
if (transformModule) {
sharedTransformModule =
std::make_shared<OwningOpRef<ModuleOp>>(transformModule.clone());
}
return success();
}

void runOnOperation() override {
if (failed(transform::applyTransformNamedSequence(
getOperation(), sharedTransformModule->get(),
options.enableExpensiveChecks(true), entryPoint)))
return signalPassFailure();
}

private:
/// Transform interpreter options.
transform::TransformOptions options;

/// The separate transform module to be used for transformations, shared
/// across multiple instances of the pass if it is applied in parallel to
/// avoid potentially expensive cloning. MUST NOT be modified after the pass
/// has been initialized.
std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
};
} // namespace
41 changes: 41 additions & 0 deletions mlir/lib/Dialect/Transform/Transforms/PreloadLibraryPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===- PreloadLibraryPass.cpp - Pass to preload a transform library -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"

using namespace mlir;

namespace mlir {
namespace transform {
#define GEN_PASS_DEF_PRELOADLIBRARYPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir

namespace {
class PreloadLibraryPass
: public transform::impl::PreloadLibraryPassBase<PreloadLibraryPass> {
public:
using Base::Base;

LogicalResult initialize(MLIRContext *context) override {
OwningOpRef<ModuleOp> mergedParsedLibraries;
if (failed(transform::detail::assembleTransformLibraryFromPaths(
context, transformLibraryPaths, mergedParsedLibraries)))
return failure();
// TODO: use a resource blob.
auto *dialect = context->getOrLoadDialect<transform::TransformDialect>();
dialect->registerLibraryModule(std::move(mergedParsedLibraries));
return success();
}

void runOnOperation() override {}
};
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -357,59 +357,6 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
extraMappings, options);
}

/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
static LogicalResult
expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
SmallVectorImpl<std::string> &fileNames) {
for (const std::string &path : paths) {
auto loc = FileLineColLoc::get(context, path, 0, 0);

if (llvm::sys::fs::is_regular_file(path)) {
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
fileNames.push_back(path);
continue;
}

if (!llvm::sys::fs::is_directory(path)) {
return emitError(loc)
<< "'" << path << "' is neither a file nor a directory";
}

LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

std::error_code ec;
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
it != itEnd && !ec; it.increment(ec)) {
const std::string &fileName = it->path();

if (it->type() != llvm::sys::fs::file_type::regular_file) {
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
<< "'\n");
continue;
}

if (!StringRef(fileName).endswith(".mlir")) {
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
<< "' because it does not end with '.mlir'\n");
continue;
}

LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
fileNames.push_back(fileName);
}

if (ec)
return emitError(loc) << "error while opening files in '" << path
<< "': " << ec.message();
}

return success();
}

LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
ArrayRef<std::string> transformLibraryPaths,
Expand Down
Loading