pla_np.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. #!/usr/bin/env python3
  2. import numpy as np
  3. c = 0.1
  4. LIMIT = 100
  5. print_epoches = [5, 10, 100, 1000]
  6. DATA_SLICE = 5
  7. NUM_CLASS = 1
  8. INPUT_FEATURE = 22
  9. def get_one_hot(targets, nb_classes):
  10. res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
  11. return res.reshape(list(targets.shape)+[nb_classes])
  12. def load(file):
  13. raw_data = np.loadtxt(file, delimiter=',')
  14. data_size = len(raw_data)
  15. np.random.shuffle(raw_data)
  16. ret = np.split(raw_data, [1,], axis=1)
  17. return ret
  18. train_pieces = load('SPECT.train')
  19. train_label = train_pieces[0]
  20. train_data = train_pieces[1]
  21. train_label = train_label.flatten().astype(int)
  22. val_pieces = load('SPECT.test')
  23. val_label = val_pieces[0]
  24. val_data = val_pieces[1]
  25. val_label = val_label.flatten().astype(int)
  26. v = np.array(val_data, dtype='float32')
  27. v = np.concatenate((v, np.ones((len(v), 1))), axis=1)
  28. def pla(train_data, train_label, val_data, val_label, picked_class):
  29. x = np.array(train_data, dtype='float32')
  30. x = np.concatenate((x, np.ones((len(x), 1))), axis=1)
  31. for i in range(len(x)):
  32. if train_label[i] != picked_class:
  33. x[i] *= -1
  34. w = np.random.rand(INPUT_FEATURE + 1)
  35. for j in range(LIMIT):
  36. flag = False
  37. for i in x:
  38. z = sum(i * w)
  39. if z <= 0:
  40. w = w + c * i
  41. flag = True
  42. if not flag:
  43. print('')
  44. break
  45. #if j in print_epoches:
  46. print(j)
  47. print('train acc:', np.sum(np.sum(x * w, axis=1) > 0) / len(x))
  48. print('val acc:', np.sum((np.sum(v * w, axis=1) > 0) == (val_label == picked_class)) / len(v))
  49. return w
  50. w = [pla(train_data, train_label, val_data, val_label, i) for i in range(NUM_CLASS)]
  51. pred = [np.sum(v * w[i], axis=1) > 0 for i in range(NUM_CLASS)]
  52. TP = sum(pred[0][i] == 1 and val_label[i] == 1 for i in range(len(val_label)))
  53. FP = sum(pred[0][i] == 1 and val_label[i] == 0 for i in range(len(val_label)))
  54. FN = sum(pred[0][i] == 0 and val_label[i] == 1 for i in range(len(val_label)))
  55. P = TP / (TP + FP)
  56. R = TP / (TP + FN)
  57. print('val acc:', sum([all([pred[j][i] == (j == val_label[i]) for j in range(NUM_CLASS)]) for i in range(len(val_label))]) / len(val_label))
  58. print('F1', 2 * P * R / (P + R))