#!/usr/bin/env python3
import numpy as np
# Perceptron learning rate: step size applied on every misclassified sample.
c = 0.1
# Maximum number of passes (epochs) over the training set per classifier.
LIMIT = 20
# Epochs at which progress was meant to be reported; only referenced from a
# commented-out condition, so it currently has no effect.
print_epoches = [5, 10, 100, 1000]
# Number of folds the shuffled data set is split into.
DATA_SLICE = 5
# Number of classes, i.e. how many one-vs-rest perceptrons are trained.
NUM_CLASS = 3
# Number of feature columns per sample (the first column of wine.data is the
# class label and is split off separately).
INPUT_FEATURE = 13
def load(path='wine.data', num_slices=None):
    """Load a comma-separated data set, shuffle it and split it into folds.

    Args:
        path: CSV file whose first column is the class label and whose
            remaining columns are the features. Defaults to 'wine.data'.
        num_slices: Number of folds to produce. Defaults to the module-level
            DATA_SLICE when omitted.

    Returns:
        A list with one entry per fold. Each entry is a two-element list
        ``[labels, features]`` where ``labels`` has shape (rows, 1) and
        ``features`` holds the remaining columns of the same rows.
    """
    if num_slices is None:
        num_slices = DATA_SLICE
    # ndmin=2 keeps a single-row file two-dimensional, so the shuffle below
    # permutes rows instead of scrambling the values inside one row.
    raw_data = np.loadtxt(path, delimiter=',', ndmin=2)
    # In-place row shuffle so every fold gets a random mix of classes.
    np.random.shuffle(raw_data)
    folds = np.array_split(raw_data, num_slices)
    # Split each fold after column 0: label column | feature columns.
    return [np.split(fold, [1], axis=1) for fold in folds]
# Shuffled data set split into DATA_SLICE folds of [labels, features].
data_pieces = load()

# Fold held out for validation; every other fold is used for training.
val_idx = 0
train_folds = [i for i in range(DATA_SLICE) if i != val_idx]

# Shift labels down by one so they line up with the 0-based class ids the
# script iterates over (range(NUM_CLASS)), then flatten to a 1-D int array.
train_label = np.concatenate([data_pieces[i][0] for i in train_folds])
train_label = (train_label - 1).flatten().astype(int)
train_data = np.concatenate([data_pieces[i][1] for i in train_folds])

val_label = data_pieces[val_idx][0]
val_data = data_pieces[val_idx][1]
val_label = (val_label - 1).flatten().astype(int)

# Validation features with a constant 1 appended as the bias input.
v = np.array(val_data, dtype='float32')
v = np.concatenate((v, np.ones((len(v), 1))), axis=1)
def pla(train_data, train_label, val_data, val_label, picked_class,
        learning_rate=None, max_epochs=None):
    """Train a one-vs-rest perceptron for ``picked_class``.

    Samples whose label differs from ``picked_class`` are negated, so a
    sample is classified correctly exactly when its dot product with the
    weight vector is strictly positive.

    Args:
        train_data: 2-D array-like of training features (rows are samples).
        train_label: Integer class id for each training row.
        val_data: 2-D array-like of validation features.
        val_label: Integer class id for each validation row.
        picked_class: Class id treated as the positive class.
        learning_rate: Step size of each weight update. Defaults to the
            module-level ``c``.
        max_epochs: Maximum number of passes over the training set.
            Defaults to the module-level ``LIMIT``.

    Returns:
        1-D weight vector of length ``n_features + 1``; the last entry is
        the bias weight.
    """
    if learning_rate is None:
        learning_rate = c
    if max_epochs is None:
        max_epochs = LIMIT

    train_label = np.asarray(train_label).ravel()
    val_label = np.asarray(val_label).ravel()

    # Append a constant 1 to every sample so the bias is learned as a weight.
    x = np.array(train_data, dtype='float32')
    x = np.concatenate((x, np.ones((len(x), 1))), axis=1)
    # Negate the "rest" samples: the update rule then only needs to test
    # whether the score is positive.
    x[train_label != picked_class] *= -1

    # Build the validation matrix from the argument that was passed in,
    # rather than depending on a module-level global.
    val = np.array(val_data, dtype='float32')
    val = np.concatenate((val, np.ones((len(val), 1))), axis=1)
    val_is_picked = val_label == picked_class

    # Size the weights from the data instead of a hard-coded feature count.
    w = np.random.rand(x.shape[1])
    for epoch in range(max_epochs):
        updated = False
        for sample in x:
            if sample @ w <= 0:
                w = w + learning_rate * sample
                updated = True
        if not updated:
            # A full pass without any correction: training has converged.
            print('')
            break
        # NOTE(review): the original indentation was lost; per-epoch
        # reporting after the convergence check is assumed here.
        print(epoch)
        print('train acc:', np.sum(x @ w > 0) / len(x))
        print('val acc:', np.sum((val @ w > 0) == val_is_picked) / len(val))
    return w
# Train one perceptron per class (one-vs-rest).
w = [pla(train_data, train_label, val_data, val_label, i) for i in range(NUM_CLASS)]

# pred[k][n] is True when classifier k claims validation sample n.
pred = [np.sum(v * w[i], axis=1) > 0 for i in range(NUM_CLASS)]

# A sample counts as correct only when every classifier agrees with its
# one-hot label: the true class fires and all the others stay silent.
one_hot = np.arange(NUM_CLASS)[:, None] == val_label[None, :]
all_agree = np.all(np.array(pred) == one_hot, axis=0)
print('val acc:', int(np.sum(all_agree)) / len(val_label))