Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions rust/tvm-graph-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ use tvm_sys::DataType;

#[derive(Debug, Error)]
pub enum GraphFormatError {
#[error("Could not parse graph json")]
Parse(#[from] serde_json::Error),
#[error("Could not parse graph params")]
Params,
#[error("{0} is missing attr: {1}")]
#[error("Failed to parse graph with error: {0}")]
Parse(#[source] serde_json::Error),
#[error("Failed to parse graph parameters with error: {0:?}")]
Params(#[source] Option<nom::Err<(Vec<u8>, nom::error::ErrorKind)>>),
#[error("{0} is missing attribute: {1}")]
MissingAttr(String, String),
#[error("Graph has invalid attr that can't be parsed: {0}")]
InvalidAttr(#[from] std::num::ParseIntError),
#[error("Failed to parse graph attribute '{0}' with error: {1}")]
InvalidAttr(String, #[source] std::num::ParseIntError),
#[error("Missing field: {0}")]
MissingField(&'static str),
#[error("Invalid DLType: {0}")]
Expand Down
49 changes: 34 additions & 15 deletions rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use nom::{
character::complete::{alpha1, digit1},
complete, count, do_parse, length_count, map, named,
number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8},
opt, tag, take, tuple,
opt, tag, take, tuple, Err as NomErr,
};
use serde::{Deserialize, Serialize};
use serde_json;
Expand Down Expand Up @@ -121,27 +121,37 @@ impl Node {
.attrs
.as_ref()
.ok_or_else(|| GraphFormatError::MissingAttr(self.name.clone(), "attrs".to_owned()))?;

let func_name = get_node_attr!(self.name, attrs, "func_name")?.to_owned();

let num_outputs = get_node_attr!(self.name, attrs, "num_outputs")?
.parse::<usize>()
.map_err(|error| GraphFormatError::InvalidAttr("num_outputs".to_string(), error))?;

let flatten_data = get_node_attr!(self.name, attrs, "flatten_data")?
.parse::<u8>()
.map(|val| val == 1)
.map_err(|error| GraphFormatError::InvalidAttr("flatten_data".to_string(), error))?;

Ok(NodeAttrs {
func_name: get_node_attr!(self.name, attrs, "func_name")?.to_owned(),
num_outputs: get_node_attr!(self.name, attrs, "num_outputs")?.parse::<usize>()?,
flatten_data: get_node_attr!(self.name, attrs, "flatten_data")?.parse::<u8>()? == 1,
func_name,
num_outputs,
flatten_data,
})
}
}

impl<'a> TryFrom<&'a String> for Graph {
type Error = GraphFormatError;
fn try_from(graph_json: &String) -> Result<Self, GraphFormatError> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error))
}
}

impl<'a> TryFrom<&'a str> for Graph {
type Error = GraphFormatError;
fn try_from(graph_json: &'a str) -> Result<Self, Self::Error> {
let graph = serde_json::from_str(graph_json)?;
Ok(graph)
serde_json::from_str(graph_json).map_err(|error| GraphFormatError::Parse(error))
}
}

Expand Down Expand Up @@ -475,14 +485,23 @@ named! {

/// Loads a param dict saved using `relay.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) {
if remaining_bytes.is_empty() {
Ok(param_dict)
} else {
Err(GraphFormatError::Params)
match parse_param_dict(bytes) {
Ok((remaining_bytes, param_dict)) => {
if remaining_bytes.is_empty() {
Ok(param_dict)
} else {
Err(GraphFormatError::Params(None))
}
}
} else {
Err(GraphFormatError::Params)
Err(error) => Err(match error {
NomErr::Incomplete(error) => GraphFormatError::Params(Some(NomErr::Incomplete(error))),
NomErr::Error((remainder, error_kind)) => {
GraphFormatError::Params(Some(NomErr::Error((remainder.into(), error_kind))))
}
NomErr::Failure((remainder, error_kind)) => {
GraphFormatError::Params(Some(NomErr::Failure((remainder.into(), error_kind))))
}
}),
}
}

Expand Down
10 changes: 5 additions & 5 deletions rust/tvm-graph-rt/src/module/syslib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

use std::{
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex,
collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::RwLock,
};

use lazy_static::lazy_static;
Expand All @@ -35,14 +35,14 @@ extern "C" {
}

lazy_static! {
static ref SYSTEM_LIB_FUNCTIONS: Mutex<HashMap<String, &'static (dyn PackedFunc)>> =
Mutex::new(HashMap::new());
static ref SYSTEM_LIB_FUNCTIONS: RwLock<HashMap<String, &'static (dyn PackedFunc)>> =
RwLock::new(HashMap::new());
}

impl Module for SystemLibModule {
fn get_function<S: AsRef<str>>(&self, name: S) -> Option<&(dyn PackedFunc)> {
SYSTEM_LIB_FUNCTIONS
.lock()
.read()
.unwrap()
.get(name.as_ref())
.copied()
Expand All @@ -65,7 +65,7 @@ pub extern "C" fn TVMBackendRegisterSystemLibSymbol(
func: BackendPackedCFunc,
) -> i32 {
let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() };
SYSTEM_LIB_FUNCTIONS.lock().unwrap().insert(
SYSTEM_LIB_FUNCTIONS.write().unwrap().insert(
name.to_string(),
&*Box::leak(super::wrap_backend_packed_func(name.to_string(), func)),
);
Expand Down