#!/usr/bin/env python
# coding: utf-8

# In[1]:


import mxnet as mx
import numpy as np
import gzip, struct
import argparse
import errno
import logging
import math
import os


# In[2]:


def download_file(url, local_fname=None, force_write=False):
    # requests is not installed by default
    import requests
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    dir_name = os.path.dirname(local_fname)
    if dir_name != "":
        if not os.path.exists(dir_name):
            try:  # try to create the directory if it doesn't exist
                os.makedirs(dir_name)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise

    r = requests.get(url, stream=True)
    assert r.status_code == 200, "failed to open %s" % url
    with open(local_fname, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
    return local_fname


# In[3]:


def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data', label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        # np.fromstring is deprecated for binary data; np.frombuffer is the
        # drop-in replacement
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data', image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255


def get_mnist_iter(args):
    """
    create data iterators with NDArrayIter
    """
    (train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, args.batch_size)
    return (train, val)


# In[4]:


def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.sym.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return mlp


# In[5]:


def _load_model(args, rank=0):
    # 'load_epoch' may be absent from args, so probe it with getattr
    if getattr(args, 'load_epoch', None) is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, args.load_epoch)
    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
    if args.model_prefix is None:
        return None
    # create the destination directory if the prefix contains one
    dst_dir = os.path.dirname(args.model_prefix)
    if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank))
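# In[ ]:


# Optional sanity check (not part of the original pipeline): a minimal sketch
# of inspecting the MLP's parameter and output shapes with Symbol.infer_shape
# before training. The batch size of 64 is an assumption matching the args
# class defined further below.
mlp = get_symbol()
arg_shapes, out_shapes, _ = mlp.infer_shape(data=(64, 1, 28, 28))
print(dict(zip(mlp.list_arguments(), arg_shapes)))  # e.g. 'fc1_weight': (128, 784)
print(out_shapes)                                   # [(64, 10)] from the softmax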
# In[15]:


def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore; also used by model.fit below
    kv = mx.kvstore.create(args.kv_store)

    # data iterators
    (train, val) = data_loader(args)

    # load a checkpoint if args.load_epoch is set
    sym, arg_params, aux_params = _load_model(args, kv.rank)
    if sym is not None:
        assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, 0)

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'multi_precision': True}

    # Only a limited number of optimizers have a 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        if 'dist' in args.kv_store:
            nworkers = kv.num_workers
        else:
            nworkers = 1
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs

    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    supported_loss = ['ce', 'nll_loss']
    if len(args.loss) > 0:
        # ce or nll loss is only applicable to softmax output
        loss_type_list = args.loss.split(',')
        if 'softmax_output' in network.list_outputs():
            for loss_type in loss_type_list:
                loss_type = loss_type.strip()
                if loss_type == 'nll':
                    loss_type = 'nll_loss'
                if loss_type not in supported_loss:
                    logging.warning(loss_type + ' is not a valid loss type; only cross-entropy and '
                                    'negative log-likelihood loss are supported!')
                else:
                    eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning("The output is not softmax_output, loss argument will be skipped!")

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(
        args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)


# In[ ]:


class args:
    gpus = None
    batch_size = 64
    disp_batches = 100
    num_epochs = 20
    lr = .05
    model_prefix = "mnist_mlp"
    wd = 0.0001
    optimizer = 'sgd'
    mom = 0.9
    monitor = 0
    loss = ''
    load_epoch = None   # train from scratch; set an epoch number to resume
    kv_store = 'local'  # single-machine training

sym = get_symbol()
fit(args, sym, get_mnist_iter)
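# In[ ]:


# Optional follow-up (a sketch, not part of the original notebook): reload the
# final checkpoint written by fit() above and score the validation set with
# the Module API. This assumes training completed all args.num_epochs epochs,
# so the last checkpoint file is "mnist_mlp-0020.params".
sym_loaded, arg_params, aux_params = mx.model.load_checkpoint(
    args.model_prefix, args.num_epochs)
mod = mx.mod.Module(symbol=sym_loaded, context=mx.cpu())
mod.bind(for_training=False,
         data_shapes=[('data', (args.batch_size, 1, 28, 28))],
         label_shapes=[('softmax_label', (args.batch_size,))])
mod.set_params(arg_params, aux_params)
_, val_iter = get_mnist_iter(args)
print(mod.score(val_iter, ['accuracy']))  # list of (metric, value) pairs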