CHEN Yihui il y a 4 ans
Parent
commit
d5df897a48
1 fichiers modifiés avec 258 ajouts et 0 suppressions
  1. 258 0
      3_10.py

+ 258 - 0
3_10.py

@@ -0,0 +1,258 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[1]:
+
+
+import mxnet as mx
+import numpy as np
+import gzip, struct
+import argparse
+import os
+
+
+# In[2]:
+
+
+def download_file(url, local_fname=None, force_write=False):
+    # requests is not default installed
+    import requests
+    if local_fname is None:
+        local_fname = url.split('/')[-1]
+    if not force_write and os.path.exists(local_fname):
+        return local_fname
+
+    dir_name = os.path.dirname(local_fname)
+
+    if dir_name != "":
+        if not os.path.exists(dir_name):
+            try: # try to create the directory if it doesn't exists
+                os.makedirs(dir_name)
+            except OSError as exc:
+                if exc.errno != errno.EEXIST:
+                    raise
+
+    r = requests.get(url, stream=True)
+    assert r.status_code == 200, "failed to open %s" % url
+    with open(local_fname, 'wb') as f:
+        for chunk in r.iter_content(chunk_size=1024):
+            if chunk: # filter out keep-alive new chunks
+                f.write(chunk)
+    return local_fname
+
+
+# In[3]:
+
+
+def read_data(label, image):
+    """
+    download and read data into numpy
+    """
+    base_url = 'http://yann.lecun.com/exdb/mnist/'
+    with gzip.open(download_file(base_url+label, os.path.join('data',label))) as flbl:
+        magic, num = struct.unpack(">II", flbl.read(8))
+        label = np.fromstring(flbl.read(), dtype=np.int8)
+    with gzip.open(download_file(base_url+image, os.path.join('data',image)), 'rb') as fimg:
+        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
+        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
+    return (label, image)
+
+
+def to4d(img):
+    """
+    reshape to 4D arrays
+    """
+    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255
+
+def get_mnist_iter(args):
+    """
+    create data iterator with NDArrayIter
+    """
+    (train_lbl, train_img) = read_data(
+            'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
+    (val_lbl, val_img) = read_data(
+            't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
+    train = mx.io.NDArrayIter(
+        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
+    val = mx.io.NDArrayIter(
+        to4d(val_img), val_lbl, args.batch_size)
+    return (train, val)
+
+
+# In[4]:
+
+
+def get_symbol(num_classes=10, **kwargs):
+    data = mx.symbol.Variable('data')
+    data = mx.sym.Flatten(data=data)
+    fc1  = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
+    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+    fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
+    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+    fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
+    mlp  = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
+    return mlp
+
+
+# In[5]:
+
+
+def _load_model(args, rank=0):
+    if 'load_epoch' not in args or args.load_epoch is None:
+        return (None, None, None)
+    assert args.model_prefix is not None
+    model_prefix = args.model_prefix
+    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
+        model_prefix += "-%d" % (rank)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(
+        model_prefix, args.load_epoch)
+    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
+    return (sym, arg_params, aux_params)
+
+
+def _save_model(args, rank=0):
+    if args.model_prefix is None:
+        return None
+    dst_dir = os.path.dirname(args.model_prefix)
+    try:
+        os.path.isdir(dst_dir)
+    except FileNotFoundError:
+        os.mkdir(dst_dir)
+    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
+        args.model_prefix, rank))
+
+
+# In[15]:
+
+
+def fit(args, network, data_loader, **kwargs):
+    """
+    train a model
+    args : argparse returns
+    network : the symbol definition of the nerual network
+    data_loader : function that returns the train and val data iterators
+    """
+    # data iterators
+    (train, val) = data_loader(args)
+
+    # save model
+    checkpoint = _save_model(args, 0)
+
+    # devices for training
+    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
+        mx.gpu(int(i)) for i in args.gpus.split(',')]
+
+    # create model
+    model = mx.mod.Module(
+        context=devs,
+        symbol=network
+    )
+
+    optimizer_params = {
+        'learning_rate': args.lr,
+        'wd': args.wd,
+        'multi_precision': True}
+
+    # Only a limited number of optimizers have 'momentum' property
+    has_momentum = {'sgd', 'dcasgd', 'nag'}
+    if args.optimizer in has_momentum:
+        optimizer_params['momentum'] = args.mom
+
+    monitor = mx.mon.Monitor(
+        args.monitor, pattern=".*") if args.monitor > 0 else None
+
+    # A limited number of optimizers have a warmup period
+    has_warmup = {'lbsgd', 'lbnag'}
+    if args.optimizer in has_warmup:
+        if 'dist' in args.kv_store:
+            nworkers = kv.num_workers
+        else:
+            nworkers = 1
+        epoch_size = args.num_examples / args.batch_size / nworkers
+        if epoch_size < 1:
+            epoch_size = 1
+        macrobatch_size = args.macrobatch_size
+        if macrobatch_size < args.batch_size * nworkers:
+            macrobatch_size = args.batch_size * nworkers
+        #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
+        batch_scale = math.ceil(
+            float(macrobatch_size) / args.batch_size / nworkers)
+        optimizer_params['updates_per_epoch'] = epoch_size
+        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
+        optimizer_params['batch_scale'] = batch_scale
+        optimizer_params['warmup_strategy'] = args.warmup_strategy
+        optimizer_params['warmup_epochs'] = args.warmup_epochs
+        optimizer_params['num_epochs'] = args.num_epochs
+
+    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)
+
+    # evaluation metrices
+    eval_metrics = ['accuracy']
+
+    supported_loss = ['ce', 'nll_loss']
+    if len(args.loss) > 0:
+        # ce or nll loss is only applicable to softmax output
+        loss_type_list = args.loss.split(',')
+        if 'softmax_output' in network.list_outputs():
+            for loss_type in loss_type_list:
+                loss_type = loss_type.strip()
+                if loss_type == 'nll':
+                    loss_type = 'nll_loss'
+                if loss_type not in supported_loss:
+                    logging.warning(loss_type + ' is not an valid loss type, only cross-entropy or '                                     'negative likelihood loss is supported!')
+                else:
+                    eval_metrics.append(mx.metric.create(loss_type))
+        else:
+            logging.warning("The output is not softmax_output, loss argument will be skipped!")
+
+    # callbacks that run after each batch
+    batch_end_callbacks = [mx.callback.Speedometer(
+        args.batch_size, args.disp_batches)]
+    if 'batch_end_callback' in kwargs:
+        cbs = kwargs['batch_end_callback']
+        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
+
+    # run
+    model.fit(train,
+              begin_epoch=args.load_epoch if args.load_epoch else 0,
+              num_epoch=args.num_epochs,
+              eval_data=val,
+              eval_metric=eval_metrics,
+              kvstore=kv,
+              optimizer=args.optimizer,
+              optimizer_params=optimizer_params,
+              initializer=initializer,
+              arg_params=arg_params,
+              aux_params=aux_params,
+              batch_end_callback=batch_end_callbacks,
+              epoch_end_callback=checkpoint,
+              allow_missing=True,
+              monitor=monitor)
+
+
+# In[ ]:
+
+
+class args:
+    gpus           = None
+    batch_size     = 64
+    disp_batches   = 100
+    num_epochs     = 20
+    lr             = .05
+    model_prefix   = "minist_mlp"
+    wd             = 0.0001
+    optimizer      = 'sgd'
+    mom            = 0.9
+    monitor        = 0
+    loss           = ''
+
+    
+sym = get_symbol()
+fit(args, sym, get_mnist_iter)
+
+
+# In[ ]:
+
+
+
+