This module contains the general structure shared by all models in this library.
class AbsModel(nn.Module):
    """
    Abstract base class with the architecture and functionality shared by all models in this library.

    Contains:
        Tracker -> visualizes the progress (only created when ``args.track`` is truthy)
        Prep-input -> extracts the model input from a batch, depending on the dataset type
    """

    def __init__(self, args):
        """Init the abstract model.

        Args:
            args: configuration object. This class reads ``args.dataset_type``
                (selects the input-prep function, see ``select_prep``) and
                ``args.track`` (enables the Tracker); ``args`` is also passed
                through to ``Tracker`` when tracking is enabled.
        """
        super().__init__()

        # Setup the input loader
        self.prep = self.select_prep(args.dataset_type)

        # Setup the tracker to visualize the progress. The attribute is always
        # defined (None when tracking is off) so watch_progress() can check it
        # instead of crashing with AttributeError.
        self.tracker = Tracker(args) if args.track else None

    def select_prep(self, mode):
        """Return the input-prep function for the given dataset type.

        Args:
            mode: dataset type name; supported values are "MRNet" and "KneeXray".

        Returns:
            A callable taking one batch ``data`` and returning the model input.
            For an unsupported ``mode`` the returned callable raises
            ``ValueError`` when invoked, so constructing the model with an
            unknown dataset type still succeeds, as it did before.
        """
        switcher = {
            "MRNet": self.prep_mrnet_input,
            "KneeXray": self.prep_kneexray_input,
        }
        prep = switcher.get(mode)
        if prep is not None:
            return prep

        valid_modes = sorted(switcher)

        def _invalid_prep(data):
            # Accepts the same single argument as the real prep functions so the
            # caller gets this clear error instead of a lambda-arity TypeError.
            raise ValueError(
                f"Invalid Dataset Type: {mode!r} (expected one of {valid_modes})"
            )

        return _invalid_prep

    def prep_mrnet_input(self, data):
        """
        This function deals with the MRNet input.
        data = {"img": x, ...}  -> returns data["img"]
        """
        return data["img"]

    def prep_kneexray_input(self, data):
        """
        This function deals with the KneeXray input.
        data = [x, y]  -> returns data[0]
        """
        return data[0]

    @torch.no_grad()
    def watch_progress(self, test_data, iteration):
        """Outsourced to Tracker; runs with gradient tracking disabled.

        Does nothing when tracking is disabled (``args.track`` was falsy at init).
        """
        tracker = getattr(self, "tracker", None)
        if tracker is None:
            return
        tracker.track_progress(self, test_data, iteration)

class AbsModel[source]

AbsModel(args) :: Module

This class contains the general architecture and functionality shared by all models in this library. It contains: Tracker -> to visualize the progress; Prep-input -> to handle the input smoothly depending on the dataset.