每日一博 | 使用逻辑回归通过歌曲列表预判用户性别
前言
上一篇写了推荐系统最古老的的一种算法叫协同过滤,古老并不是不好用,其实还是很好用的一种算法,随着时代的进步,出现了神经网络和因子分解等更优秀的算法解决不同的问题。 这里主要说一下逻辑回归,逻辑回归主要用于打分的预估。我这里没有打分的数据所以用性别代替。 这里的例子就是用歌曲列表预判用户性别。
什么是逻辑回归
逻辑回归的资料比较多,我比较推荐大家看刷一下bilibili上李宏毅老师的视频,这里我只说一些需要注意的点。
网络结构
逻辑回归可以理解为一种单层神经网络,网络结构如图:
激活函数选择
逻辑回归一般选sigmoid或者softmax
- 图的上半部分就是二元逻辑回归激活函数是sigmoid
- 图的下半部分是多元逻辑回归没有激活函数直接接了一个softmax
别问我啥是sigmoid啥是softmax,问就是百度。
损失函数选择
损失函数逻辑回归常用的有三种(其实有很多不止三种,自己查API喽):
- binary_crossentropy
- categorical_crossentropy
- sparse_categorical_crossentrop 这里其实用binary更合适,但是我这里选的categorical_crossentropy,因为我懒得改了,而且我后面会做其他功能
梯度下降选择
梯度下降方式有很多,我这里选择随机梯度下降,sgd其实我觉得adam更合适,看大家心情了。至于为啥
数据准备
这次的数据是1万条KTV唱歌数据,别问我数据哪来的。问就是别人给的。
X是用户唱歌数据的one-hot
Y是用户的性别one-hot
下面是真正的技术
代码实现
- 数据拆分为 80%训练 20%测试
- 这里虽然只有两类但是还是用了softmax,不影响
- 训练工具是keras
数据获取
下面代码都干了些啥呢,主要是两个matrix。
一个是用户唱歌的onehot->song_hot_matrix。
一个是用户性别的onehot->decades_hot_matrix。 代码不重要,主要看字。
import elasticsearch import elasticsearch.helpers import re import numpy as np import operator import datetime es_client = elasticsearch.Elasticsearch(hosts=["localhost:9200"]) def trim_song_name(song_name): """ 处理歌名,过滤掉无用内容和空白 """ song_name = song_name.strip() song_name = re.sub("【.*?】", "", song_name) song_name = re.sub("(.*?)", "", song_name) return song_name def trim_address_name(address_name): """ 处理地址 """ return str(address_name).strip() def get_data(size=0): """ 获取uid=>作品名list的字典 """ cur_size=0 song_dic = {} user_address_dic = {} user_decades_dic = {} search_result = elasticsearch.helpers.scan( es_client, index="ktv_user_info", doc_type="ktv_works", scroll="10m", query={ "query":{ "range": { "birthday": { "gt": 63072662400 } } } } ) for hit_item in search_result: cur_size += 1 if size>0 and cur_size>size: break user_info = hit_item["_source"] item = get_work_info(hit_item["_id"]) if item is None: continue work_list = item['item_list'] if len(work_list)<2: continue if user_info['gender']==0: continue if user_info['gender']==1: user_info['gender']="男" if user_info['gender']==2: user_info['gender']="女" song_dic[item['uid']] = [trim_song_name(item['songname']) for item in work_list] user_decades_dic[item['uid']] = user_info['gender'] user_address_dic[item['uid']] = trim_address_name(user_info['address']) return (song_dic, user_address_dic, user_decades_dic) def get_user_info(uid): """ 获取用户信息 """ ret = es_client.get( index="ktv_user_info", doc_type="ktv_works", id=uid ) return ret['_source'] def get_work_info(uid): """ 获取用户信息 """ try: ret = es_client.get( index="ktv_works", doc_type="ktv_works", id=uid ) return ret['_source'] except Exception as ex: return None def get_uniq_song_sort_list(song_dict): """ 合并重复歌曲并按歌曲名排序 """ return sorted(list(set(np.concatenate(list(song_dict.values())).tolist()))) from sklearn import preprocessing %run label_encoder.ipynb user_count = 4000 song_count = 0 # 获得用户唱歌数据 song_dict, user_address_dict, user_decades_dict = get_data(user_count) # 歌曲字典 song_label_encoder = LabelEncoder() song_label_encoder.fit_dict(song_dict, "", True) song_hot_matrix = song_label_encoder.encode_hot_dict(song_dict, True) user_decades_encoder = LabelEncoder() user_decades_encoder.fit_dict(user_decades_dict) decades_hot_matrix = user_decades_encoder.encode_hot_dict(user_decades_dict, False)
song_hot_matrix
uid | 洗刷刷 | 麻雀 | 你的答案 |
---|---|---|---|
0 | 0 | 1 | 0 |
1 | 1 | 1 | 0 |
2 | 1 | 0 | 0 |
3 | 0 | 0 | 0 |
decades_hot_matrix
uid | 男 | 女 |
---|---|---|
0 | 1 | 0 |
1 | 0 | 1 |
2 | 1 | 0 |
3 | 0 | 1 |
模型训练
import numpy as np from keras.models import Sequential from keras.layers import Dense, Activation, Embedding,Flatten import matplotlib.pyplot as plt from keras.utils import np_utils from sklearn import datasets from sklearn.model_selection import train_test_split n_class=user_decades_encoder.get_class_count() song_count=song_label_encoder.get_class_count() print(n_class) print(song_count) # 拆分训练数据和测试数据 train_X,test_X, train_y, test_y = train_test_split(song_hot_matrix, decades_hot_matrix, test_size = 0.2, random_state = 0) train_count = np.shape(train_X)[0] # 构建神经网络模型 model = Sequential() model.add(Dense(input_dim=8, units=n_class)) model.add(Activation('softmax')) # 选定loss函数和优化器 model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) # 训练过程 print('Training -----------') for step in range(train_count): scores = model.train_on_batch(train_X, train_y) if step % 50 == 0: print("训练样本 %d 个, 损失: %f, 准确率: %f" % (step, scores[0], scores[1]*100)) print('finish!')
准确率测试集评估
数据训练完了用拆分出来的20%数据测试一下:
# 准确率评估 from sklearn.metrics import classification_report scores = model.evaluate(test_X, test_y, verbose=0) print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100)) Y_test = np.argmax(test_y, axis=1) y_pred = model.predict_classes(test_X) print(classification_report(Y_test, y_pred))
输出:
accuracy: 78.43% precision recall f1-score support 0 0.72 0.90 0.80 220 1 0.88 0.68 0.77 239 accuracy 0.78 459 macro avg 0.80 0.79 0.78 459 weighted avg 0.80 0.78 0.78 459
人工测试
然后让小伙伴们一起来玩耍,嗯准确率100%,完美!
def pred(song_list=[]): blong_hot_matrix = song_label_encoder.encode_hot_dict({"bblong":song_list}, True) y_pred = model.predict_classes(blong_hot_matrix) return user_decades_encoder.decode_list(y_pred) # # 男A # print(pred(["一路向北", "暗香", "菊花台"])) # # 男B # print(pred(["不要说话", "平凡之路", "李白"])) # # 女A # print(pred(["知足", "被风吹过的夏天", "龙卷风"])) # # 男C # print(pred(["情人","再见","无赖","离人","你的样子"])) # # 男D # print(pred(["小情歌","我好想你","无与伦比的美丽"])) # # 男E # print(pred(["忐忑","最炫民族风","小苹果"]))
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
FreeBSD 对 802.11ac 的支持进度太缓慢,基金会终于出手赞助
尽管 Windows 和 Linux 都提供了对 802.11ac("WiFi 5")的良好支持,并且它他们最新的无线芯片组已经开始将重心放在 802.11ax ("WiFi 6")上,但 FreeBSD 仍在解决支持802.11ac 标准的问题。因此,FreeBSD 基金会准备尽快开始为支持 802.11ac 的开发工作提供赞助。 英特尔一直是增加FreeBSD 硬件支持的主要硬件供应商之一,而 FreeBSD 基金会一直在为开发人员购买笔记本电脑,以增强 FreeBSD 对现代笔记本电脑的支持。但是即使做出了这样的努力,并且大多数 FreeBSD 安装在服务器上,其对 802.11ac 的支持仍然是比较落后。 不少用户对此也比较着急,他们在推特上催促 FreeBSD 基金会尽快赞助 FreeBSD,以推进对802.11ac 支持的开发工作。 但这项工作并没那么容易,我们也看到 FreeBSD.org Wiki 页面概述了许多在 FreeBSD 中支持 802.11ac 无线标准尚需完成的项目。目前看来,还有很多工作要做,需要数年才能完成。 而现在如果有了 FreeBSD 基金会的赞...
- 下一篇
2020年云迁移比较大的挑战
【金融特辑】光大银行科技部DBA女神带你从0到1揭秘MGR 云计算越来越受欢迎,但是希望在竞争中脱颖而出的企业2020年必须克服其挑战。 2020年将是云迁移和采用以及云支出将猛增多达17%的一年。 随着越来越多的企业加入云计算潮流,许多人会发现,云计算之旅并非总是一帆风顺。相同的挑战一次又一次地出现,通常使企业无法从云中获得真正的好处。 选择哪种云计算提供商? 在开始进行企业云迁移的旅程时,企业通常会获得支持,以使用卓越云中心进行开发或在公共云提供商中构建整个产品。但是,很多企业都在做好下一步措施:将鸡蛋放入哪个篮子?也就是选择哪个云计算提供商? 传统上,企业没有针对像AWS和谷歌这样的云计算提供商的企业支持结构(对微软公司情况则不同,通常具有良好的支持关系)。取而代之的是,他们决定采用征求建议书或选择更容易加入的供应商。这两种方法都没有考虑到对工程和开发人员体验的影响。 那么,企业应该采用哪种方法?也许企业已经有一个由特定云计算提供商培训的工程师小团队。也许有一款需要开发的产品,可以很好地使用某个云计算提供商提供的服务。或者,企业的业务运营和客户所处的地区最适合某个云计算提供商的地...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
-
Docker使用Oracle官方镜像安装(12C,18C,19C)
- Springboot2将连接池hikari替换为druid,体验最强大的数据库连接池
- CentOS8编译安装MySQL8.0.19
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- SpringBoot2配置默认Tomcat设置,开启更多高级功能
- MySQL8.0.19开启GTID主从同步CentOS8
- CentOS7,8上快速安装Gitea,搭建Git服务器
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- SpringBoot2编写第一个Controller,响应你的http请求并返回结果
推荐阅读
最新文章
- CentOS6,CentOS7官方镜像安装Oracle11G
- Windows10,CentOS7,CentOS8安装Nodejs环境
- CentOS8编译安装MySQL8.0.19
- SpringBoot2整合Thymeleaf,官方推荐html解决方案
- 设置Eclipse缩进为4个空格,增强代码规范
- CentOS7,8上快速安装Gitea,搭建Git服务器
- Windows10,CentOS7,CentOS8安装MongoDB4.0.16
- CentOS7安装Docker,走上虚拟化容器引擎之路
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- CentOS7编译安装Cmake3.16.3,解决mysql等软件编译问题