博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
基于tensorflow的逻辑分类
阅读量:5240 次
发布时间:2019-06-14

本文共 3844 字,大约阅读时间需要 12 分钟。

#!/usr/local/bin/python3##ljj [2]##logic classify model import tensorflow as tfimport matplotlib.pyplot as pltimport pandas as pdimport numpy as npdata_set = pd.read_csv('LogiReg_data.txt',sep=',')#data_set.describe()w = tf.Variable(tf.random_normal([2,1]),dtype="float32")b = tf.Variable(tf.random_normal([1]),dtype="float32")y = tf.placeholder(tf.float32)x = tf.placeholder(tf.float32,shape=(1,2))loss_list = []with tf.Session() as sess:#定义逻辑回归模型        logits = tf.add(tf.matmul(x,w),b)        y_predict = tf.nn.sigmoid(logits)                loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,labels=y))                   train = tf.train.AdamOptimizer(0.001).minimize(loss)        sess.run(tf.global_variables_initializer())                for j in range(1500):                #shuffle data_set                #index = np.random.permutation(100)                #data_set = data_set.reindex(index)                                for i in range(100):                        w_,b_,loss_,_ = sess.run([w,b,loss,train],feed_dict={x:data_set[i:i+1][['math','english']],y:data_set[i:i+1][['result']]})                if j%100==0:                        print('epoch={}, w={},b={},loss={}'.format(j,w_,b_,loss_))                loss_list.append(loss_)        print('final result : ')        print('w={},b={},loss={}'.format(w_,b_,loss_))              train_data = data_set.values        x1 = train_data[:,0]        x2 = train_data[:,1]        y = train_data[:,-1:]                  for x1p, x2p, yp in zip(x1, x2, y):                if yp == 0:                        plt.scatter(x1p, x2p, marker='x', c='r')                else:                        plt.scatter(x1p, x2p, marker='o', c='g')                        # 根据参数得到直线        x = np.linspace(20, 100, 10)        y = []        for i in x:            y.append((i * -w_[1] - b_) / w_[0])                    plt.plot(x, y)        plt.show()

 运行输出:

ljjdeMBP:logic_classify lingjiajun$ ./logic_regression.py 

/usr/local/Cellar/python3/3.6.2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:205: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6

  return f(*args, **kwds)

2018-05-06 21:48:14.420588: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA

epoch=0, w=[[-0.87034857]

 [ 0.13095166]],b=[ 1.48850453],loss=52.04541015625

epoch=100, w=[[ 0.01144427]

 [ 0.0005749 ]],b=[ 0.58342618],loss=0.21249079704284668

epoch=200, w=[[ 0.02113499]

 [ 0.01194377]],b=[-0.85776216],loss=0.16103345155715942

epoch=300, w=[[ 0.03020949]

 [ 0.02251359]],b=[-2.18144464],loss=0.12141948938369751

epoch=400, w=[[ 0.03859403]

 [ 0.03212684]],b=[-3.3784802],loss=0.092116579413414

epoch=500, w=[[ 0.04626059]

 [ 0.04076466]],b=[-4.4525094],loss=0.07090871036052704

epoch=600, w=[[ 0.05323409]

 [ 0.04850558]],b=[-5.41535854],loss=0.05559533089399338

epoch=700, w=[[ 0.059574  ]

 [ 0.05546409]],b=[-6.28165531],loss=0.04442552104592323

epoch=800, w=[[ 0.06535295]

 [ 0.06175429]],b=[-7.06552744],loss=0.03614450991153717

epoch=900, w=[[ 0.07064275]

 [ 0.067476  ]],b=[-7.77926588],loss=0.029891693964600563

epoch=1000, w=[[ 0.07550841]

 [ 0.07271299]],b=[-8.43318176],loss=0.02508264034986496

epoch=1100, w=[[ 0.08000626]

 [ 0.07753391]],b=[-9.0357523],loss=0.021319210529327393

epoch=1200, w=[[ 0.08418395]

 [ 0.0819957 ]],b=[-9.59397316],loss=0.01832636632025242

epoch=1300, w=[[ 0.08808059]

 [ 0.08614379]],b=[-10.11344337],loss=0.015911955386400223

epoch=1400, w=[[ 0.09172987]

 [ 0.09001698]],b=[-10.59893036],loss=0.01393833290785551

final result : 

w=[[ 0.09512767]

 [ 0.09361333]],b=[-11.05011368],loss=0.012320424430072308

 

转载于:https://www.cnblogs.com/lingjiajun/p/9000024.html

你可能感兴趣的文章
Html 小插件5 百度搜索代码2
查看>>
nodejs-Path模块
查看>>
P1107 最大整数
查看>>
EasyDarwin开源手机直播方案:EasyPusher手机直播推送,EasyDarwin流媒体服务器,EasyPlayer手机播放器...
查看>>
监控CPU和内存的使用
查看>>
Ubuntu14.04设置开机自启动程序
查看>>
ios app 单元测试 自动化测试
查看>>
年薪二十万
查看>>
强连通tarjan模版
查看>>
javascript_09-数组
查看>>
多进程与多线程的区别
查看>>
PAT 1145 1078| hashing哈希表 平方探测法
查看>>
Ubuntu(虚拟机)下安装Qt5.5.1
查看>>
Linux第七周学习总结——可执行程序的装载
查看>>
java.io.IOException: read failed, socket might closed or timeout, read ret: -1
查看>>
细说php(二) 变量和常量
查看>>
iOS开发网络篇之Web Service和XML数据解析
查看>>
个人寒假作业项目《印象笔记》第一天
查看>>
java 常用命令
查看>>
ZOJ 1666 G-Square Coins
查看>>