相关原理这里就不再赘述了，直接贴代码：
这里求解的是李航《统计学习方法》第2章（感知机）中的例2.1：正实例点 (3,3)、(4,3)，负实例点 (1,1)，用感知机学习算法的原始形式求模型。
# -*- coding:utf-8 -*-
import numpy as np
# Training samples: each entry is an (x, y) pair where x is the feature tuple
# and y is the class label (+1 or -1).
# A plain list is used on purpose: the rows are ragged (a tuple plus a scalar),
# and np.array() on ragged nested sequences raises ValueError on NumPy >= 1.24.
training_set = [((3, 3), 1), ((4, 3), 1), ((1, 1), -1)]
# Weight vector, one integer component per feature (the learning rate of 1
# used by update() keeps every component integral).
w = np.zeros(len(training_set[0][0]), dtype=int)
# Bias term.
b = 0
def update(item):
    """Apply one perceptron update step for a misclassified sample.

    Args:
        item: an ``(x, y)`` pair where ``x`` is a sequence of feature values
            and ``y`` is the class label (+1 or -1 in ``training_set``).

    Side effects:
        Mutates the module-level ``w`` in place and rebinds the module-level
        ``b``: ``w <- w + eta * y * x`` and ``b <- b + eta * y`` with
        learning rate ``eta = 1``. Prints the new ``w`` and ``b``.
    """
    global w, b
    # Learning rate. Kept at 1 so the integer-typed ``w`` stays exact;
    # assigning a fractional result into ``w`` would be silently truncated.
    eta = 1
    x, y = item[0], item[1]
    # Update every feature component instead of hard-coding exactly two.
    for i, xi in enumerate(x):
        w[i] += eta * y * xi
    b += eta * y
    # Emits the same "<w> <b>" line as the Python 2 ``print w, b`` did,
    # but is also valid Python 3.
    print("%s %s" % (w, b))
def cal(item):
    """Return the functional margin ``y * (w . x + b)`` of one sample.

    ``item`` is an ``(x, y)`` pair. The module-level ``w`` and ``b`` are read
    (not modified). ``check`` treats a result ``<= 0`` as a misclassification.
    """
    features, label = item[0], item[1]
    activation = 0
    # Dot product of the features with the current weight vector.
    for idx, value in enumerate(features):
        activation += value * w[idx]
    activation += b
    return activation * label
def check():
    """Run one pass over ``training_set``, updating on every misclassified sample.

    A sample counts as misclassified when ``cal(item) <= 0``, so points lying
    exactly on the separating hyperplane are also treated as errors.
    ``update`` is applied immediately, so later samples in the same pass are
    evaluated against the already-updated ``w`` and ``b``.

    Returns:
        1 if at least one update was made during this pass, 0 otherwise.
        When no update was needed, the final ``w`` and ``b`` are printed.
    """
    updated = 0
    for item in training_set:
        if cal(item) <= 0:
            updated = 1
            update(item)
    if not updated:
        # A parenthesised single-string print behaves identically on
        # Python 2 and Python 3.
        print("RESULT: w: " + str(w) + " b: " + str(b))
    return updated
if __name__ == "__main__":
    # Upper bound on the number of passes over the data. The perceptron is
    # only guaranteed to converge on linearly separable data, so the loop
    # must not run forever.
    MAX_EPOCHS = 1000
    for _ in range(MAX_EPOCHS):
        # check() returns 0 once a full pass makes no update (converged).
        if not check():
            break
    else:
        # The loop ran out without converging; say so instead of exiting silently.
        print("NOT CONVERGED after %d epochs: w: %s b: %s" % (MAX_EPOCHS, w, b))
perceptron
socket编程
Hello World
>