#!/usr/bin/env python
# coding: utf-8

# In[1]:

import mxnet as mx
import numpy as np
import gzip, struct
import argparse
import errno
import logging
import math
import os
# In[2]:

def download_file(url, local_fname=None, force_write=False):
    # requests is not installed by default
    import requests
    if local_fname is None:
        local_fname = url.split('/')[-1]
    if not force_write and os.path.exists(local_fname):
        return local_fname

    dir_name = os.path.dirname(local_fname)
    if dir_name != "":
        if not os.path.exists(dir_name):
            try:  # try to create the directory if it doesn't exist
                os.makedirs(dir_name)
            except OSError as exc:
                if exc.errno != errno.EEXIST:
                    raise

    r = requests.get(url, stream=True)
    assert r.status_code == 200, "failed to open %s" % url
    with open(local_fname, 'wb') as f:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
    return local_fname
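
# In[ ]:

# download_file caches by path: a second call with the same local_fname returns
# immediately without re-downloading. A quick check (a sketch, not part of the
# original script; it fetches the same MNIST mirror that read_data below relies
# on, so it does real network I/O):
print(download_file('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
                    os.path.join('data', 't10k-labels-idx1-ubyte.gz')))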
# In[3]:

def read_data(label, image):
    """
    download and read data into numpy
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url+label, os.path.join('data', label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        # np.fromstring is deprecated for binary data; np.frombuffer is the drop-in replacement
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url+image, os.path.join('data', image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255


def get_mnist_iter(args):
    """
    create data iterators with NDArrayIter
    """
    (train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, args.batch_size)
    return (train, val)
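
# In[ ]:

# A quick shape check on the data pipeline (a sketch; _probe_args is a
# throwaway stand-in for the argparse namespace, not part of the original
# script). Each batch should be (batch_size, 1, 28, 28) images with a
# (batch_size,) label vector.
class _probe_args:
    batch_size = 64

train_iter, _ = get_mnist_iter(_probe_args)
batch = train_iter.next()
print(batch.data[0].shape, batch.label[0].shape)  # expect (64, 1, 28, 28) and (64,)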
# In[4]:

def get_symbol(num_classes=10, **kwargs):
    """
    a simple multilayer perceptron: flatten -> 128 -> relu -> 64 -> relu -> num_classes -> softmax
    """
    data = mx.symbol.Variable('data')
    data = mx.sym.Flatten(data=data)
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=num_classes)
    mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return mlp
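
# In[ ]:

# Inspect the MLP layer by layer (a sketch, not part of the original script):
# mx.viz.print_summary prints each layer's output shape and parameter count
# given an input shape. The batch size of 1 here is an arbitrary choice for
# the shape probe.
mx.viz.print_summary(get_symbol(), shape={'data': (1, 1, 28, 28)})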
# In[5]:

def _load_model(args, rank=0):
    # hasattr handles args objects that never define load_epoch at all
    if not hasattr(args, 'load_epoch') or args.load_epoch is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, args.load_epoch)
    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
    if args.model_prefix is None:
        return None
    dst_dir = os.path.dirname(args.model_prefix)
    # os.path.isdir never raises; create the checkpoint directory if it is missing
    if dst_dir != "" and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank))
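
# In[ ]:

# What _save_model leaves on disk: one '<prefix>-symbol.json' plus a
# '<prefix>-%04d.params' file per epoch, which mx.model.load_checkpoint reads
# back as (symbol, arg_params, aux_params). A guarded sketch (only runs once
# training below has produced the files; 'mnist_mlp' matches args.model_prefix):
if os.path.exists('mnist_mlp-symbol.json') and os.path.exists('mnist_mlp-0001.params'):
    _, arg_ck, _ = mx.model.load_checkpoint('mnist_mlp', 1)
    print(sorted(arg_ck.keys()))  # fc1/fc2/fc3 weights and biases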
# In[15]:

def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore ('local' for a single machine, 'dist_*' for distributed training)
    kv = mx.kvstore.create(args.kv_store)
    # data iterators
    (train, val) = data_loader(args)
    # load model state if resuming from a checkpoint
    sym, arg_params, aux_params = _load_model(args, kv.rank)
    if sym is not None:
        # sanity check: the loaded symbol must match the network being trained
        assert sym.tojson() == network.tojson()
    # save model
    checkpoint = _save_model(args, 0)
    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]
    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )
    optimizer_params = {
        'learning_rate': args.lr,
        'wd': args.wd,
        'multi_precision': True}
    # Only a limited number of optimizers have a 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period; the extra attributes
    # (num_examples, macrobatch_size, warmup_*) are only read on this path
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        if 'dist' in args.kv_store:
            nworkers = kv.num_workers
        else:
            nworkers = 1
        epoch_size = args.num_examples / args.batch_size / nworkers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs

    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    supported_loss = ['ce', 'nll_loss']
    if len(args.loss) > 0:
        # ce or nll loss is only applicable to softmax output
        loss_type_list = args.loss.split(',')
        if 'softmax_output' in network.list_outputs():
            for loss_type in loss_type_list:
                loss_type = loss_type.strip()
                if loss_type == 'nll':
                    loss_type = 'nll_loss'
                if loss_type not in supported_loss:
                    logging.warning(loss_type + ' is not a valid loss type; only cross-entropy '
                                    'or negative log-likelihood loss is supported!')
                else:
                    eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning("The output is not softmax_output; the loss argument will be skipped!")

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(
        args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)
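
# In[ ]:

# argparse is imported at the top but never used: the class in the next cell
# fakes the parsed namespace instead. A minimal sketch of the equivalent
# command-line wiring (argument names chosen here to mirror the attributes
# fit() reads; not part of the original script):
def parse_args():
    parser = argparse.ArgumentParser(description='train an MLP on MNIST')
    parser.add_argument('--gpus', type=str, default=None,
                        help="comma-separated GPU ids, e.g. '0,1'; omit to use the CPU")
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=.05)
    parser.add_argument('--model-prefix', type=str, default='mnist_mlp')
    return parser.parse_args()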
# In[ ]:

class args:
    # a plain class standing in for an argparse namespace; fit() reads these
    gpus = None
    batch_size = 64
    disp_batches = 100
    num_epochs = 20
    lr = .05
    model_prefix = "mnist_mlp"
    wd = 0.0001
    optimizer = 'sgd'
    mom = 0.9
    monitor = 0
    loss = ''
    kv_store = 'local'  # fit() creates its kvstore from this
    load_epoch = None   # set to an epoch number to resume from a checkpoint

sym = get_symbol()
fit(args, sym, get_mnist_iter)
# In[ ]:
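
# Inference with the trained model (a sketch; assumes fit() above completed and
# wrote checkpoints under args.model_prefix). Module.predict runs the softmax
# output over the validation iterator and returns per-class probabilities.
sym_inf, arg_inf, aux_inf = mx.model.load_checkpoint(args.model_prefix, args.num_epochs)
mod = mx.mod.Module(symbol=sym_inf, context=mx.cpu(),
                    data_names=['data'], label_names=['softmax_label'])
_, val_iter = get_mnist_iter(args)
mod.bind(for_training=False, data_shapes=val_iter.provide_data,
         label_shapes=val_iter.provide_label)
mod.set_params(arg_inf, aux_inf)
probs = mod.predict(val_iter)
print('predicted digit for the first validation image:',
      int(probs[0].asnumpy().argmax()))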