import mxnet as mx
import numpy as np
import gzip, struct
import errno
import logging
import math
import os

def download_file(url, local_fname=None, force_write=False):
    """
    download url into local_fname, skipping the download if the
    file already exists (unless force_write is set)
    """
    import requests
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    dir_name = os.path.dirname(local_fname)
    if dir_name != "":
        if not os.path.exists(dir_name):
            try:
                os.makedirs(dir_name)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise

    r = requests.get(url, stream=True)
    assert r.status_code == 200, "failed to open %s" % url
    with open(local_fname, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)
    return local_fname

def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data', label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data', image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)

def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255

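# Optional sanity check (a sketch; _check_mnist_shapes is not part of the
# original script): the MNIST training set should come back as 60000 labels
# and 60000 28x28 images, which to4d maps to (60000, 1, 28, 28) in [0, 1].
def _check_mnist_shapes():
    lbl, img = read_data('train-labels-idx1-ubyte.gz',
                         'train-images-idx3-ubyte.gz')
    x = to4d(img)
    assert lbl.shape == (60000,)
    assert x.shape == (60000, 1, 28, 28)
    assert 0.0 <= float(x.min()) <= float(x.max()) <= 1.0
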
def get_mnist_iter(args):
    """
    create data iterator with NDArrayIter
    """
    (train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, args.batch_size)
    return (train, val)

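# Note: each DataBatch from these iterators carries batch.data[0] with shape
# (batch_size, 1, 28, 28) and batch.label[0] with shape (batch_size,);
# shuffle=True randomizes the order of the training examples.
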
def get_symbol(num_classes=10, **kwargs):
    data = mx.symbol.Variable('data')
    data = mx.symbol.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return mlp

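# Shape-inference sketch (an optional helper, not in the original script):
# Symbol.infer_shape propagates an example input shape through the graph,
# confirming that the flattened 784-dim input passes through 128- and 64-unit
# ReLU layers into a (batch_size, 10) softmax output.
def _print_mlp_shapes(batch_size=64):
    mlp = get_symbol()
    arg_shapes, out_shapes, _ = mlp.infer_shape(data=(batch_size, 1, 28, 28))
    for name, shape in zip(mlp.list_arguments(), arg_shapes):
        print(name, shape)
    print('output:', out_shapes)  # [(batch_size, 10)]
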
def _load_model(args, rank=0):
    # getattr works whether args is an argparse namespace or a plain class
    if getattr(args, 'load_epoch', None) is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, args.load_epoch)
    logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)

def _save_model(args, rank=0):
    if args.model_prefix is None:
        return None
    dst_dir = os.path.dirname(args.model_prefix)
    if dst_dir and not os.path.isdir(dst_dir):
        # os.path.isdir does not raise, so check and create explicitly
        os.makedirs(dst_dir)
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank))

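# How the two helpers fit together: _save_model returns a do_checkpoint
# callback that writes <model_prefix>-symbol.json plus <model_prefix>-NNNN.params
# after every epoch, and setting args.load_epoch = N makes _load_model restore
# exactly those files so training can resume from epoch N.
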
def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse namespace (or any object with the same attributes)
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # create the kvstore; model.fit and the branches below need it
    kv = mx.kvstore.create(args.kv_store)

    (train, val) = data_loader(args)

    # resume from a checkpoint if args.load_epoch is set
    sym, arg_params, aux_params = _load_model(args, kv.rank)
    if sym is not None:
        assert sym.tojson() == network.tojson()

    checkpoint = _save_model(args, kv.rank)

    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'multi_precision': True}

    # only some optimizers accept a momentum argument
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # large-batch optimizers need warmup and epoch-size information
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        if 'dist' in args.kv_store:
            nworkers = kv.num_workers
        else:
            nworkers = 1
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers

        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs

    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

    eval_metrics = ['accuracy']
    supported_loss = ['ce', 'nll_loss']
    if len(args.loss) > 0:
        # validate the requested loss metrics against the supported ones
        loss_type_list = args.loss.split(',')
        if 'softmax_output' in network.list_outputs():
            for loss_type in loss_type_list:
                loss_type = loss_type.strip()
                if loss_type == 'nll':
                    loss_type = 'nll_loss'
                if loss_type not in supported_loss:
                    logging.warning(loss_type + ' is not a valid loss type; only '
                                    'cross-entropy and negative log-likelihood '
                                    'loss are supported!')
                else:
                    eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning("The output is not softmax_output; the loss argument "
                            "will be skipped!")

    batch_end_callbacks = [mx.callback.Speedometer(
        args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)

class args:
    gpus = None
    batch_size = 64
    disp_batches = 100
    num_epochs = 20
    lr = 0.05
    model_prefix = "mnist_mlp"
    wd = 0.0001
    optimizer = 'sgd'
    mom = 0.9
    monitor = 0
    loss = ''
    kv_store = 'local'    # fit() builds its kvstore from this
    load_epoch = None     # set to an epoch number to resume from a checkpoint
    num_examples = 60000  # size of the MNIST training set

# log at INFO level so Speedometer and checkpoint messages are visible
logging.basicConfig(level=logging.INFO)

sym = get_symbol()
fit(args, sym, get_mnist_iter)

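# Reloading the result for evaluation (a sketch; assumes the run above
# finished and wrote mnist_mlp-0020.params for its final epoch):
def _evaluate_checkpoint(epoch=20, batch_size=64):
    sym, arg_params, aux_params = mx.model.load_checkpoint('mnist_mlp', epoch)
    mod = mx.mod.Module(symbol=sym, context=mx.cpu())
    mod.bind(data_shapes=[('data', (batch_size, 1, 28, 28))],
             label_shapes=[('softmax_label', (batch_size,))],
             for_training=False)
    mod.set_params(arg_params, aux_params)
    _, val = get_mnist_iter(args)
    print(mod.score(val, 'acc'))  # [('accuracy', <value>)]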