This commit is contained in:
eson 2021-04-02 18:40:15 +08:00
parent 5d3130169e
commit 65bb5b387d
3 changed files with 13 additions and 9 deletions

12
data.py
View File

@ -106,14 +106,14 @@ def load_pay_data(textNum = 80):
lastday_v = collect_pay[0]
for cur_v in collect_pay[1:]:
total_coin = 0
users = 0
last_total_coin = 0
total_coin = 0.0
users = 0.0
last_total_coin = 0.0
for v2 in lastday_v:
last_total_coin += v2[0] + v2[1]
count = 0
count = 0.0
for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] + v1[1]
users += v1[2]
@ -125,10 +125,10 @@ def load_pay_data(textNum = 80):
# print(compare)
# 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
x_train.append([count ,total_coin / last_total_coin , total_coin])
x_train.append([count, total_coin, last_total_coin ])
count+=1
for i in range(count):
for i in range(int(count)):
y_train.append(total_coin)
lastday_v = cur_v

View File

@ -2,10 +2,12 @@ import numpy
from keras.models import load_model
from data import load_pay_data, load_gift_data
import matplotlib
import matplotlib.pyplot as plt
x_train, y_train, tx_train, ty_train, _ = load_pay_data(160)
x_train, y_train, tx_train, ty_train, _ = load_pay_data(320)
model = load_model("./predict_pay")
p_data = model.predict(tx_train)

View File

@ -16,6 +16,8 @@ import numpy
from data import load_pay_data
def mean_squared_error(y_true, y_pred):
print(dir(y_true), y_true.consumers)
print(y_true, y_pred)
return K.mean(K.square(y_pred - y_true), axis=-1)
if __name__ == "__main__":
@ -31,9 +33,9 @@ if __name__ == "__main__":
# model.add(Dropout(0.1))
model.add(Dense(1))
model.summary()
model.compile(loss = 'msle', optimizer = 'adam')
model.compile(loss = "mse", optimizer = 'adam')
model.fit(x_train, y_train, batch_size=96, epochs=600)
model.fit(x_train, y_train, batch_size=128, epochs=500)
model.save("./predict_pay")
p_data = model.predict(tx_train)