Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions deepem/test/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import imp
import numpy as np
import os
from types import SimpleNamespace
Expand All @@ -11,7 +10,7 @@

def load_model(opt):
# Create a model.
mod = imp.load_source('model', opt.model)
mod = py_utils.load_module('model', opt.model)
if opt.onnx:
model = OnnxModel(mod.create_model(opt), opt)
else:
Expand Down
9 changes: 5 additions & 4 deletions deepem/train/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import imp
from deepem.utils.py_utils import load_module

import numpy as np

import torch
Expand Down Expand Up @@ -45,20 +46,20 @@ def requires_grad(self, key):
def build(self, opt, data, is_train, prob):
# Data augmentation
if opt.augment:
mod = imp.load_source('augment', opt.augment)
mod = load_module('augment', opt.augment)
aug = mod.get_augmentation(is_train, **opt.aug_params)
else:
aug = None

# Data sampler
mod = imp.load_source('sampler', opt.sampler)
mod = load_module('sampler', opt.sampler)
spec = mod.get_spec(opt.in_spec, opt.out_spec)
zspecs = opt.zettaset_specs
sampler = mod.Sampler(data, spec, is_train, aug, prob, zspecs)

# Sample modifier
if opt.modifier:
mod = imp.load_source('modifier', opt.modifier)
mod = load_module('modifier', opt.modifier)
self.modifier = mod.Modifier(**opt.modifier_kwargs)
else:
def default_modifier(x, **kwargs):
Expand Down
6 changes: 4 additions & 2 deletions deepem/train/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import imp
import os
import glob

Expand All @@ -9,6 +8,7 @@
from deepem.train.data import Data
from deepem.train.model import Model, AmpModel
from deepem.loss.utils import BinaryWeightBalancer
from deepem.utils.py_utils import load_module


def get_criteria(opt):
Expand Down Expand Up @@ -59,7 +59,9 @@ def get_criteria(opt):

def load_model(opt):
# Create a model.
mod = imp.load_source('model', opt.model)

mod = load_module("model", opt.model)

if opt.mixed_precision:
model = AmpModel(mod.create_model(opt), get_criteria(opt), opt)
else:
Expand Down
10 changes: 10 additions & 0 deletions deepem/utils/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@

from sklearn.decomposition import PCA

import importlib
import importlib.util


def load_module(name, path):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def dict2tuple(d):
return namedtuple('GenericDict', d.keys())(**d)
Expand Down