@@ -0,0 +1,258 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[1]:
+
+
+import argparse
+import errno
+import gzip
+import logging
+import math
+import os
+import struct
+
+import mxnet as mx
+import numpy as np
+
+# The checkpoint and training messages below are emitted through `logging`,
+# so raise the default level to INFO to make them visible.
+logging.basicConfig(level=logging.INFO)
+
+
+# In[2]:
+
+
+def download_file(url, local_fname=None, force_write=False):
+    # requests is not installed by default
+    import requests
+    if local_fname is None:
+        local_fname = url.split('/')[-1]
+    if not force_write and os.path.exists(local_fname):
+        return local_fname
+
+    dir_name = os.path.dirname(local_fname)
+
+    if dir_name != "":
+        if not os.path.exists(dir_name):
+            try:  # try to create the directory if it doesn't exist
+                os.makedirs(dir_name)
+            except OSError as exc:
+                # another process may have created it in the meantime
+                if exc.errno != errno.EEXIST:
+                    raise
+
+    # stream the download so large files never sit in memory all at once
+    r = requests.get(url, stream=True)
+    assert r.status_code == 200, "failed to open %s" % url
+    with open(local_fname, 'wb') as f:
+        for chunk in r.iter_content(chunk_size=1024):
+            if chunk:  # filter out keep-alive chunks
+                f.write(chunk)
+    return local_fname
+
+
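+# In[ ]:
+
+
+# download_file in isolation (a quick check, not needed for training: it
+# fetches one small MNIST label file into ./data, and the os.path.exists
+# guard above reuses the cached copy on later runs):
+print(download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+                    os.path.join('data', 't10k-labels-idx1-ubyte.gz')))
+
+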
+# In[3]:
+
+
+def read_data(label, image):
+    """
+    download and read data into numpy
+    """
+    base_url = 'http://yann.lecun.com/exdb/mnist/'
+    # IDX format: a big-endian header (magic number, item count, and for
+    # images the row and column counts) followed by the raw bytes
+    with gzip.open(download_file(base_url+label, os.path.join('data', label)), 'rb') as flbl:
+        magic, num = struct.unpack(">II", flbl.read(8))
+        label = np.frombuffer(flbl.read(), dtype=np.int8)
+    with gzip.open(download_file(base_url+image, os.path.join('data', image)), 'rb') as fimg:
+        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+    return (label, image)
+
+
+def to4d(img):
+    """
+    reshape to 4D arrays (batch, channel, height, width) and scale to [0, 1]
+    """
+    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+
+def get_mnist_iter(args):
+    """
+    create data iterators with NDArrayIter
+    """
+    (train_lbl, train_img) = read_data(
+        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
+    (val_lbl, val_img) = read_data(
+        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
+    train = mx.io.NDArrayIter(
+        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
+    val = mx.io.NDArrayIter(
+        to4d(val_img), val_lbl, args.batch_size)
+    return (train, val)
+
+
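+# In[ ]:
+
+
+# Quick sanity check of the loaders (a minimal sketch; downloads ~11 MB into
+# ./data on first use): the t10k set holds 10,000 images, and to4d should
+# yield unit-scaled (N, 1, 28, 28) arrays.
+lbl, img = read_data('t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
+print(lbl[:5], to4d(img).shape)  # e.g. [7 2 1 0 4] (10000, 1, 28, 28)
+
+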
+# In[4]:
+
+
+def get_symbol(num_classes=10, **kwargs):
+    """
+    a simple multilayer perceptron: 784 -> 128 -> 64 -> num_classes
+    """
+    data = mx.symbol.Variable('data')
+    data = mx.symbol.Flatten(data=data)
+    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
+    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
+    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
+    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
+    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
+    return mlp
+
+
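+# In[ ]:
+
+
+# Shape check for the MLP above (a minimal sketch; the batch size 64 just
+# mirrors the default used below). infer_shape propagates a (64, 1, 28, 28)
+# input through Flatten and the three FullyConnected layers.
+_, out_shapes, _ = get_symbol().infer_shape(data=(64, 1, 28, 28))
+print(out_shapes)  # expected: [(64, 10)], one row of class scores per image
+
+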
+# In[5]:
+
+
+def _load_model(args, rank=0):
+    if getattr(args, 'load_epoch', None) is None:
+        return (None, None, None)
+    assert args.model_prefix is not None
+    model_prefix = args.model_prefix
+    # each worker beyond rank 0 loads its own checkpoint if one exists
+    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
+        model_prefix += "-%d" % (rank)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(
+        model_prefix, args.load_epoch)
+    logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
+    return (sym, arg_params, aux_params)
+
+
+def _save_model(args, rank=0):
+    if args.model_prefix is None:
+        return None
+    # make sure the destination directory exists; os.path.isdir only returns
+    # a boolean, so the directory has to be created explicitly
+    dst_dir = os.path.dirname(args.model_prefix)
+    if dst_dir and not os.path.isdir(dst_dir):
+        os.makedirs(dst_dir)
+    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
+        args.model_prefix, rank))
+
+
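+# Checkpoint layout, for reference: mx.callback.do_checkpoint("mnist_mlp")
+# writes mnist_mlp-symbol.json once plus mnist_mlp-0001.params,
+# mnist_mlp-0002.params, ... after each epoch, which is exactly the naming
+# mx.model.load_checkpoint in _load_model expects. To resume a run, set
+# args.load_epoch to a saved epoch number before calling fit() below.
+
+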
+# In[15]:
+
+
+def fit(args, network, data_loader, **kwargs):
+    """
+    train a model
+    args : argparse returns
+    network : the symbol definition of the neural network
+    data_loader : function that returns the train and val data iterators
+    """
+    # key-value store for (possibly distributed) parameter updates
+    kv = mx.kvstore.create(args.kv_store)
+
+    # data iterators
+    (train, val) = data_loader(args)
+
+    # load model parameters when resuming from a checkpoint
+    sym, arg_params, aux_params = _load_model(args, kv.rank)
+    if sym is not None:
+        assert sym.tojson() == network.tojson()
+
+    # save model
+    checkpoint = _save_model(args, kv.rank)
+
+    # devices for training
+    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
+        mx.gpu(int(i)) for i in args.gpus.split(',')]
+
+    # create model
+    model = mx.mod.Module(
+        context=devs,
+        symbol=network
+    )
+
+    # multi_precision keeps a float32 copy of the weights when training
+    # with float16 data
+    optimizer_params = {
+        'learning_rate': args.lr,
+        'wd': args.wd,
+        'multi_precision': True}
+
+    # only a limited number of optimizers have a 'momentum' property
+    has_momentum = {'sgd', 'dcasgd', 'nag'}
+    if args.optimizer in has_momentum:
+        optimizer_params['momentum'] = args.mom
+
+    monitor = mx.mon.Monitor(
+        args.monitor, pattern=".*") if args.monitor > 0 else None
+
+    # a limited number of optimizers have a warmup period
+    has_warmup = {'lbsgd', 'lbnag'}
+    if args.optimizer in has_warmup:
+        if 'dist' in args.kv_store:
+            nworkers = kv.num_workers
+        else:
+            nworkers = 1
+        epoch_size = args.num_examples / args.batch_size / nworkers
+        if epoch_size < 1:
+            epoch_size = 1
+        macrobatch_size = args.macrobatch_size
+        if macrobatch_size < args.batch_size * nworkers:
+            macrobatch_size = args.batch_size * nworkers
+        batch_scale = math.ceil(
+            float(macrobatch_size) / args.batch_size / nworkers)
+        optimizer_params['updates_per_epoch'] = epoch_size
+        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
+        optimizer_params['batch_scale'] = batch_scale
+        optimizer_params['warmup_strategy'] = args.warmup_strategy
+        optimizer_params['warmup_epochs'] = args.warmup_epochs
+        optimizer_params['num_epochs'] = args.num_epochs
+
+    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
+
+    # evaluation metrics
+    eval_metrics = ['accuracy']
+
+    supported_loss = ['ce', 'nll_loss']
+    if len(args.loss) > 0:
+        # ce or nll loss is only applicable to softmax output
+        loss_type_list = args.loss.split(',')
+        if 'softmax_output' in network.list_outputs():
+            for loss_type in loss_type_list:
+                loss_type = loss_type.strip()
+                if loss_type == 'nll':
+                    loss_type = 'nll_loss'
+                if loss_type not in supported_loss:
+                    logging.warning(loss_type + ' is not a valid loss type; only cross-entropy '
+                                    'and negative log-likelihood loss are supported!')
+                else:
+                    eval_metrics.append(mx.metric.create(loss_type))
+        else:
+            logging.warning("The output is not softmax_output; the loss argument will be skipped!")
+
+    # callbacks that run after each batch
+    batch_end_callbacks = [mx.callback.Speedometer(
+        args.batch_size, args.disp_batches)]
+    if 'batch_end_callback' in kwargs:
+        cbs = kwargs['batch_end_callback']
+        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
+
+    # run
+    model.fit(train,
+              begin_epoch=args.load_epoch if args.load_epoch else 0,
+              num_epoch=args.num_epochs,
+              eval_data=val,
+              eval_metric=eval_metrics,
+              kvstore=kv,
+              optimizer=args.optimizer,
+              optimizer_params=optimizer_params,
+              initializer=initializer,
+              arg_params=arg_params,
+              aux_params=aux_params,
+              batch_end_callback=batch_end_callbacks,
+              epoch_end_callback=checkpoint,
+              allow_missing=True,
+              monitor=monitor)
+
+
+# In[ ]:
+
+
+class args:
+    # a plain namespace standing in for the argparse results the functions
+    # above expect
+    gpus = None          # e.g. "0" or "0,1" to train on GPUs (needs a CUDA build)
+    batch_size = 64
+    disp_batches = 100
+    num_epochs = 20
+    lr = .05
+    model_prefix = "mnist_mlp"
+    wd = 0.0001
+    optimizer = 'sgd'
+    mom = 0.9
+    monitor = 0
+    loss = ''
+    kv_store = 'local'   # single-machine training
+    load_epoch = None    # set to an epoch number to resume from a checkpoint
+
+
+sym = get_symbol()
+fit(args, sym, get_mnist_iter)
+
+
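+# In[ ]:
+
+
+# Using the trained model (a minimal sketch, assuming the run above finished
+# and saved mnist_mlp-0020.params): reload the last checkpoint and score the
+# validation images.
+sym_loaded, arg_p, aux_p = mx.model.load_checkpoint(args.model_prefix, args.num_epochs)
+mod = mx.mod.Module(symbol=sym_loaded, context=mx.cpu(), label_names=None)
+mod.bind(for_training=False, data_shapes=[('data', (args.batch_size, 1, 28, 28))])
+mod.set_params(arg_p, aux_p, allow_missing=True)
+
+_, val_iter = get_mnist_iter(args)
+probs = mod.predict(val_iter)
+print(probs.shape)  # (10000, 10): class probabilities for each test image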