feat(todo): 完善测试的算法

This commit is contained in:
eson 2021-03-26 17:14:06 +08:00
parent fbcbeca35d
commit a7ce020942
2 changed files with 142 additions and 65 deletions

202
data.py
View File

@ -1,3 +1,6 @@
import logging
import traceback
from keras.models import Sequential
from keras.layers import Dense, Dropout, Embedding
from keras.layers import InputLayer
@ -10,7 +13,25 @@ import os
import numpy
import time, datetime
def get_collect(region = "all"):
regions = [
"all",
"Region_Arab",
"Region_China",
"Region_English",
"Region_Germany",
"Region_India",
"Region_Indonesia",
"Region_Japan",
"Region_Philippines",
"Region_Portuguese",
"Region_Russian",
"Region_Spanish",
"Region_Thailand",
"Region_Turkey",
"Region_Vietnam",
]
def get_collect():
collect = {}
loadfile = "./collect.pickle"
@ -18,66 +39,120 @@ def get_collect(region = "all"):
collect = pickle.load(open(loadfile, 'rb'))
except Exception as e:
print(e)
try:
# 打开数据库连接
db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8")
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = db.cursor()
today = time.strftime("%Y-%m-%d", time.localtime())
# 使用 execute() 方法执行 SQL 查询
cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = "%s" and platform="all" and create_at <= %s''',region , today)
collect_pay = {}
for row in cursor.fetchall():
# print(row)
coin, extra_coins, pay_users, create_at = row
d = str(create_at.date())
if d in collect_pay:
collect_pay[d].append(row)
else:
collect_pay[d] = [ row ]
# print(dir(create_at), create_at.timestamp(), create_at.date())
print('共查找出', cursor.rowcount, '条数据')
deletelist = []
for k in collect_pay:
if len(collect_pay[k]) != 24:
deletelist.append(k)
db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8")
# 使用 cursor() 方法创建一个游标对象 cursor
cursor = db.cursor()
today = time.strftime("%Y-%m-%d", time.localtime())
for region in regions:
for k in deletelist:
del collect_pay[k]
# 使用 execute() 方法执行 SQL 查询
cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = %s and country = "all" and platform="all" and create_at >= "2021-02-23" and create_at <= %s''',(region , today))
collect_pay = []
for row in cursor.fetchall():
# print(row)
coin, extra_coins, pay_users, create_at = row
rowlist = [coin, extra_coins, pay_users, create_at.hour, create_at]
# print(dir(create_at), create_at.hour)
collect_pay.append(rowlist)
# d = str(create_at.date())
# if d in collect_pay:
# collect_pay.append(row)
# else:
# collect_pay[d] = [ row ]
# print(dir(create_at), create_at.timestamp(), create_at.date())
print('共查找出', cursor.rowcount, '条数据')
querydate= []
for k in collect_pay:
querydate.append(k)
querydate.sort()
cursor.execute(
'''SELECT coin, users, create_at from gift_items_hour pih where region = "all" and create_at >= %s and create_at <= %s''',
(querydate[0], querydate[-1]),
)
collect_gift = {}
for row in cursor.fetchall():
coin, users, create_at = row
d = str(create_at.date())
if d in collect_gift:
collect_gift[d].append(row)
else:
collect_gift[d] = [ row ]
for k in collect_pay:
l = collect_pay[k]
l.sort(key=lambda x:x[3])
# if cursor.rowcount <= 500:
# collect["pay-" + region] = None
# collect["gift-" + region] = None
# continue
for k in collect_gift:
l = collect_gift[k]
l.sort(key=lambda x:x[2])
# deletelist = []
# for k in collect_pay:
# if len(collect_pay[k]) != 24:
# deletelist.append(k)
collect["pay"] = collect_pay
collect["gift"] = collect_gift
# for k in deletelist:
# del collect_pay[k]
# querydate= []
# for k in collect_pay:
# querydate.append(k)
# querydate.sort()
cursor.execute(
'''SELECT coin, users, create_at from gift_items_hour pih where region = %s and country = "all" and create_at >= "2021-02-23" and create_at <= %s''',
(region, today),
)
collect_gift = []
for row in cursor.fetchall():
coin, users, create_at = row
rowlist = [coin, users, create_at.hour, create_at]
collect_gift.append(rowlist)
# d = str(create_at.date())
# if d in collect_gift:
# collect_gift[d].append(row)
# else:
# collect_gift[d] = [ row ]
collect_pay.sort(key=lambda x:x[-1])
collect_gift.sort(key=lambda x:x[-1])
# for k in collect_gift:
# l = collect_gift[k]
# l.sort(key=lambda x:x[2])
yesterday = {}
for v in collect_pay:
print(v[-1])
date = (v[-1].date() - datetime.timedelta(days=1)).__str__()
print(date)
if date not in yesterday:
cursor.execute(
'''SELECT coin, extra_coins, pay_users, create_at from pay_items_day pid where region = %s and country = "all" and platform="all" and create_at = %s''',
(region , date),
)
row = cursor.fetchone()
coin, extra_coins, pay_users, create_at = row
yesterday[date] = coin + extra_coins
v.insert(-2, yesterday[date])
yesterday = {}
for v in collect_gift:
print(v[-1])
date = (v[-1].date() - datetime.timedelta(days=1)).__str__()
print(date)
if date not in yesterday:
cursor.execute(
'''SELECT coin, users, create_at from gift_items_day where region = %s and country = "all" and create_at = %s''',
(region , date),
)
row = cursor.fetchone()
coin, users, create_at = row
yesterday[date] = coin
v.insert(-2, yesterday[date])
collect["pay-" + region] = collect_pay
collect["gift-" + region] = collect_gift
except Exception as e:
# print(e)
logging.error(traceback.format_exc())
pickle.dump(collect, open(loadfile, 'wb+'))
finally:
return collect
@ -93,11 +168,13 @@ def load_pay_data(textNum = 80, region = "all"):
x_train = []
y_train = []
collect_pay = []
for k in collect["pay"]:
collect_pay.append(collect["pay"][k])
rkey = "pay-" + region
collect_pay = collect[rkey]
# for k in collect[rkey]:
# collect_pay.append(collect[rkey][k])
collect_pay.sort(key=lambda x:x[0][3])
# collect_pay.sort(key=lambda x:x[0][3])
lastday_v = collect_pay[0]
for cur_v in collect_pay[1:]:
@ -105,9 +182,6 @@ def load_pay_data(textNum = 80, region = "all"):
users = 0
last_total_coin = 0
for v2 in lastday_v:
last_total_coin += v2[0] + v2[1]
count = 0
for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] + v1[1]
@ -120,7 +194,7 @@ def load_pay_data(textNum = 80, region = "all"):
# print(compare)
# 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
x_train.append([count ,total_coin / last_total_coin , total_coin])
x_train.append([v1[-2] ,total_coin / v1[-3] , total_coin])
count+=1
for i in range(count):
@ -141,16 +215,18 @@ def load_pay_data(textNum = 80, region = "all"):
return x_train, y_train, tx_train, ty_train, input_shape
def load_gift_data(textNum = 80):
def load_gift_data(textNum = 80, region = "all"):
collect = get_collect()
x_train = []
y_train = []
rkey = "gift-" + region
collect_gift = []
for k in collect["gift"]:
collect_gift.append(collect["gift"][k])
for k in collect[rkey]:
collect_gift.append(collect[rkey][k])
collect_gift.sort(key=lambda x:x[0][2])
lastday_v = collect_gift[0]

View File

@ -17,7 +17,8 @@ from data import load_pay_data
if __name__ == "__main__":
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80)
region = "Region_Arab"
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80, region)
model = Sequential()
@ -31,7 +32,7 @@ if __name__ == "__main__":
model.compile(loss = 'mse', optimizer = 'adam')
model.fit(x_train, y_train, batch_size=96, epochs=1200)
model.save("./predict_pay")
model.save("./predict_pay_" + region)
p_data = model.predict(tx_train)
for i in range(len(p_data)):