# Contains the general structure shared by all models in this library.
class AbsModel(nn.Module):
    """Base class holding the architecture and functionality shared by all models in this library.

    Contains:
    - Tracker: visualizes training progress (only created when ``args.track`` is truthy).
    - Input preparation (``self.prep``): extracts the model input from a batch according
      to the dataset type, so every dataset is handled smoothly by the same model code.
    """

    def __init__(self, args):
        """Initialize the base model.

        Args:
            args: configuration namespace; this constructor reads ``args.dataset_type``
                (selects the input-preparation function) and ``args.track``
                (whether to create a Tracker).
        """
        super().__init__()
        # Set up the input loader for the configured dataset type.
        self.prep = self.select_prep(args.dataset_type)
        # Set up the tracker to visualize progress. Always assign the attribute so
        # that watch_progress can tell "tracking disabled" apart from a bug.
        self.tracker = Tracker(args) if args.track else None

    def select_prep(self, mode):
        """Return the input-preparation callable for dataset type ``mode``.

        Args:
            mode: dataset type name, ``"MRNet"`` or ``"KneeXray"``.

        Returns:
            A callable taking one batch and returning the model input. For an
            unsupported ``mode`` the returned callable raises ``ValueError`` when
            invoked; construction itself does not fail, so models built with an
            unsupported dataset type but never calling ``prep`` keep working.
        """
        switcher = {
            "MRNet": self.prep_mrnet_input,
            "KneeXray": self.prep_kneexray_input,
        }
        prep = switcher.get(mode)
        if prep is not None:
            return prep

        def _invalid_prep(data):
            # Fail with a descriptive error instead of an arity TypeError.
            raise ValueError(
                f"Invalid Dataset Type: {mode!r} (expected one of {sorted(switcher)})"
            )

        return _invalid_prep

    def prep_mrnet_input(self, data):
        """Extract the model input from an MRNet batch.

        The batch is expected to be a mapping like ``{"img": x, ...}``;
        the value under ``"img"`` is returned.
        """
        return data["img"]

    def prep_kneexray_input(self, data):
        """Extract the model input from a KneeXray batch.

        The batch is expected to be a sequence like ``[x, y]``;
        its first element is returned.
        """
        return data[0]

    @torch.no_grad()
    def watch_progress(self, test_data, iteration):
        """Report progress via the Tracker, without gradient tracking.

        Delegates to ``self.tracker.track_progress(self, test_data, iteration)``.
        Does nothing when tracking was disabled at construction time.
        """
        # getattr also covers subclasses that never ran AbsModel.__init__.
        tracker = getattr(self, "tracker", None)
        if tracker is None:
            return
        tracker.track_progress(self, test_data, iteration)