feat(todo): 完善测试的算法

This commit is contained in:
eson 2021-03-26 17:14:06 +08:00
parent fbcbeca35d
commit a7ce020942
2 changed files with 142 additions and 65 deletions

202
data.py
View File

@@ -1,3 +1,6 @@
import logging
import traceback
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Dense, Dropout, Embedding from keras.layers import Dense, Dropout, Embedding
from keras.layers import InputLayer from keras.layers import InputLayer
@@ -10,7 +13,25 @@ import os
import numpy import numpy
import time, datetime import time, datetime
def get_collect(region = "all"): regions = [
"all",
"Region_Arab",
"Region_China",
"Region_English",
"Region_Germany",
"Region_India",
"Region_Indonesia",
"Region_Japan",
"Region_Philippines",
"Region_Portuguese",
"Region_Russian",
"Region_Spanish",
"Region_Thailand",
"Region_Turkey",
"Region_Vietnam",
]
def get_collect():
collect = {} collect = {}
loadfile = "./collect.pickle" loadfile = "./collect.pickle"
@@ -18,66 +39,120 @@ def get_collect(region = "all"):
collect = pickle.load(open(loadfile, 'rb')) collect = pickle.load(open(loadfile, 'rb'))
except Exception as e: except Exception as e:
print(e) print(e)
try:
# 打开数据库连接 # 打开数据库连接
db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8") db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8")
# 使用 cursor() 方法创建一个游标对象 cursor # 使用 cursor() 方法创建一个游标对象 cursor
cursor = db.cursor() cursor = db.cursor()
today = time.strftime("%Y-%m-%d", time.localtime()) today = time.strftime("%Y-%m-%d", time.localtime())
# 使用 execute() 方法执行 SQL 查询 for region in regions:
cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = "%s" and platform="all" and create_at <= %s''',region , today)
collect_pay = {}
for row in cursor.fetchall():
# print(row)
coin, extra_coins, pay_users, create_at = row
d = str(create_at.date())
if d in collect_pay:
collect_pay[d].append(row)
else:
collect_pay[d] = [ row ]
# print(dir(create_at), create_at.timestamp(), create_at.date())
print('共查找出', cursor.rowcount, '条数据')
deletelist = []
for k in collect_pay:
if len(collect_pay[k]) != 24:
deletelist.append(k)
for k in deletelist: # 使用 execute() 方法执行 SQL 查询
del collect_pay[k] cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = %s and country = "all" and platform="all" and create_at >= "2021-02-23" and create_at <= %s''',(region , today))
collect_pay = []
for row in cursor.fetchall():
# print(row)
coin, extra_coins, pay_users, create_at = row
rowlist = [coin, extra_coins, pay_users, create_at.hour, create_at]
# print(dir(create_at), create_at.hour)
collect_pay.append(rowlist)
# d = str(create_at.date())
# if d in collect_pay:
# collect_pay.append(row)
# else:
# collect_pay[d] = [ row ]
# print(dir(create_at), create_at.timestamp(), create_at.date())
print('共查找出', cursor.rowcount, '条数据')
querydate= []
for k in collect_pay:
querydate.append(k)
querydate.sort()
cursor.execute(
'''SELECT coin, users, create_at from gift_items_hour pih where region = "all" and create_at >= %s and create_at <= %s''',
(querydate[0], querydate[-1]),
)
collect_gift = {}
for row in cursor.fetchall():
coin, users, create_at = row
d = str(create_at.date())
if d in collect_gift:
collect_gift[d].append(row)
else:
collect_gift[d] = [ row ]
for k in collect_pay: # if cursor.rowcount <= 500:
l = collect_pay[k] # collect["pay-" + region] = None
l.sort(key=lambda x:x[3]) # collect["gift-" + region] = None
# continue
for k in collect_gift: # deletelist = []
l = collect_gift[k] # for k in collect_pay:
l.sort(key=lambda x:x[2]) # if len(collect_pay[k]) != 24:
# deletelist.append(k)
collect["pay"] = collect_pay # for k in deletelist:
collect["gift"] = collect_gift # del collect_pay[k]
# querydate= []
# for k in collect_pay:
# querydate.append(k)
# querydate.sort()
cursor.execute(
'''SELECT coin, users, create_at from gift_items_hour pih where region = %s and country = "all" and create_at >= "2021-02-23" and create_at <= %s''',
(region, today),
)
collect_gift = []
for row in cursor.fetchall():
coin, users, create_at = row
rowlist = [coin, users, create_at.hour, create_at]
collect_gift.append(rowlist)
# d = str(create_at.date())
# if d in collect_gift:
# collect_gift[d].append(row)
# else:
# collect_gift[d] = [ row ]
collect_pay.sort(key=lambda x:x[-1])
collect_gift.sort(key=lambda x:x[-1])
# for k in collect_gift:
# l = collect_gift[k]
# l.sort(key=lambda x:x[2])
yesterday = {}
for v in collect_pay:
print(v[-1])
date = (v[-1].date() - datetime.timedelta(days=1)).__str__()
print(date)
if date not in yesterday:
cursor.execute(
'''SELECT coin, extra_coins, pay_users, create_at from pay_items_day pid where region = %s and country = "all" and platform="all" and create_at = %s''',
(region , date),
)
row = cursor.fetchone()
coin, extra_coins, pay_users, create_at = row
yesterday[date] = coin + extra_coins
v.insert(-2, yesterday[date])
yesterday = {}
for v in collect_gift:
print(v[-1])
date = (v[-1].date() - datetime.timedelta(days=1)).__str__()
print(date)
if date not in yesterday:
cursor.execute(
'''SELECT coin, users, create_at from gift_items_day where region = %s and country = "all" and create_at = %s''',
(region , date),
)
row = cursor.fetchone()
coin, users, create_at = row
yesterday[date] = coin
v.insert(-2, yesterday[date])
collect["pay-" + region] = collect_pay
collect["gift-" + region] = collect_gift
except Exception as e:
# print(e)
logging.error(traceback.format_exc())
pickle.dump(collect, open(loadfile, 'wb+')) pickle.dump(collect, open(loadfile, 'wb+'))
finally: finally:
return collect return collect
@@ -93,11 +168,13 @@ def load_pay_data(textNum = 80, region = "all"):
x_train = [] x_train = []
y_train = [] y_train = []
collect_pay = [] rkey = "pay-" + region
for k in collect["pay"]: collect_pay = collect[rkey]
collect_pay.append(collect["pay"][k])
# for k in collect[rkey]:
# collect_pay.append(collect[rkey][k])
collect_pay.sort(key=lambda x:x[0][3]) # collect_pay.sort(key=lambda x:x[0][3])
lastday_v = collect_pay[0] lastday_v = collect_pay[0]
for cur_v in collect_pay[1:]: for cur_v in collect_pay[1:]:
@@ -105,9 +182,6 @@ def load_pay_data(textNum = 80, region = "all"):
users = 0 users = 0
last_total_coin = 0 last_total_coin = 0
for v2 in lastday_v:
last_total_coin += v2[0] + v2[1]
count = 0 count = 0
for v1, v2 in zip(cur_v,lastday_v): for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] + v1[1] total_coin += v1[0] + v1[1]
@@ -120,7 +194,7 @@ def load_pay_data(textNum = 80, region = "all"):
# print(compare) # print(compare)
# 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数 # 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
x_train.append([count ,total_coin / last_total_coin , total_coin]) x_train.append([v1[-2] ,total_coin / v1[-3] , total_coin])
count+=1 count+=1
for i in range(count): for i in range(count):
@@ -141,16 +215,18 @@ def load_pay_data(textNum = 80, region = "all"):
return x_train, y_train, tx_train, ty_train, input_shape return x_train, y_train, tx_train, ty_train, input_shape
def load_gift_data(textNum = 80): def load_gift_data(textNum = 80, region = "all"):
collect = get_collect() collect = get_collect()
x_train = [] x_train = []
y_train = [] y_train = []
rkey = "gift-" + region
collect_gift = [] collect_gift = []
for k in collect["gift"]: for k in collect[rkey]:
collect_gift.append(collect["gift"][k]) collect_gift.append(collect[rkey][k])
collect_gift.sort(key=lambda x:x[0][2]) collect_gift.sort(key=lambda x:x[0][2])
lastday_v = collect_gift[0] lastday_v = collect_gift[0]

View File

@@ -17,7 +17,8 @@ from data import load_pay_data
if __name__ == "__main__": if __name__ == "__main__":
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80) region = "Region_Arab"
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80, region)
model = Sequential() model = Sequential()
@@ -31,7 +32,7 @@ if __name__ == "__main__":
model.compile(loss = 'mse', optimizer = 'adam') model.compile(loss = 'mse', optimizer = 'adam')
model.fit(x_train, y_train, batch_size=96, epochs=1200) model.fit(x_train, y_train, batch_size=96, epochs=1200)
model.save("./predict_pay") model.save("./predict_pay_" + region)
p_data = model.predict(tx_train) p_data = model.predict(tx_train)
for i in range(len(p_data)): for i in range(len(p_data)):