Skip to content

Commit b228037

Browse files
Lunderbergtqchen
authored andcommitted
Expose attrs argument of "ir.IRModule" to Rust bindings
1 parent ff5118f commit b228037

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

rust/tvm/src/ir/module.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::runtime::array::Array;
2828
use crate::runtime::function::Result;
2929
use crate::runtime::map::Map;
3030
use crate::runtime::string::String as TVMString;
31-
use crate::runtime::{external, IsObjectRef, Object};
31+
use crate::runtime::{external, IsObjectRef, Object, ObjectRef};
3232

3333
use super::expr::GlobalVar;
3434
use super::function::BaseFunc;
@@ -62,7 +62,7 @@ external! {
6262
#[name("relay.parser.ParseExpr")]
6363
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
6464
#[name("ir.IRModule")]
65-
fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
65+
fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>, attrs: Map<TVMString, ObjectRef>) -> IRModule;
6666
// Module methods
6767
#[name("ir.Module_Add")]
6868
fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
@@ -99,18 +99,24 @@ external! {
9999
// Note: we don't expose update here as update is going to be removed.
100100

101101
impl IRModule {
102-
pub fn new<'a, F, T>(funcs: F, types: T) -> Result<IRModule>
102+
pub fn new<'a, F, T, A>(funcs: F, types: T, attrs: A) -> Result<IRModule>
103103
where
104104
F: IntoIterator<Item = (&'a GlobalVar, &'a BaseFunc)>,
105105
T: IntoIterator<Item = (&'a GlobalTypeVar, &'a TypeData)>,
106+
A: IntoIterator<Item = (&'a TVMString, &'a ObjectRef)>,
106107
{
107-
module_new(Map::from_iter(funcs), Map::from_iter(types))
108+
module_new(
109+
Map::from_iter(funcs),
110+
Map::from_iter(types),
111+
Map::from_iter(attrs),
112+
)
108113
}
109114

110115
pub fn empty() -> Result<IRModule> {
111116
let funcs = HashMap::<GlobalVar, BaseFunc>::new();
112117
let types = HashMap::<GlobalTypeVar, TypeData>::new();
113-
IRModule::new(funcs.iter(), types.iter())
118+
let attrs = HashMap::<TVMString, ObjectRef>::new();
119+
IRModule::new(funcs.iter(), types.iter(), attrs.iter())
114120
}
115121

116122
pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>

src/ir/module.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,21 @@ TVM_REGISTER_NODE_TYPE(IRModuleNode);
383383

384384
TVM_REGISTER_GLOBAL("ir.IRModule")
385385
.set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
386-
tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
386+
tvm::ObjectRef attrs) {
387+
auto dict_attrs = [&attrs]() {
388+
if (!attrs.defined()) {
389+
return DictAttrs();
390+
} else if (auto* as_dict_attrs = attrs.as<tvm::DictAttrsNode>()) {
391+
return GetRef<tvm::DictAttrs>(as_dict_attrs);
392+
} else if (attrs.as<tvm::MapNode>()) {
393+
return tvm::DictAttrs(Downcast<Map<String, ObjectRef>>(attrs));
394+
} else {
395+
LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map<String,ObjectRef>";
396+
}
397+
}();
398+
399+
return IRModule(funcs, types, {}, {}, dict_attrs);
400+
});
387401

388402
TVM_REGISTER_GLOBAL("ir.Module_Add")
389403
.set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {

0 commit comments

Comments
 (0)