diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f59ec20 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +* \ No newline at end of file diff --git a/data.py b/data.py index 0b99f48..71b4e89 100644 --- a/data.py +++ b/data.py @@ -1,11 +1,8 @@ -import logging -import traceback - from keras.models import Sequential from keras.layers import Dense, Dropout, Embedding from keras.layers import InputLayer from keras.layers import LSTM -from keras import backend +# from keras import backend as K import pymysql import pickle @@ -13,24 +10,6 @@ import os import numpy import time, datetime -regions = [ - "all", - "Region_Arab", - "Region_China", - "Region_English", - "Region_Germany", - "Region_India", - "Region_Indonesia", - "Region_Japan", - "Region_Philippines", - "Region_Portuguese", - "Region_Russian", - "Region_Spanish", - "Region_Thailand", - "Region_Turkey", - "Region_Vietnam", -] - def get_collect(): collect = {} loadfile = "./collect.pickle" @@ -39,126 +18,77 @@ def get_collect(): collect = pickle.load(open(loadfile, 'rb')) except Exception as e: print(e) - try: - # 打开数据库连接 - db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8") - - # 使用 cursor() 方法创建一个游标对象 cursor - cursor = db.cursor() - today = time.strftime("%Y-%m-%d", time.localtime()) - - for region in regions: + db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8") + + # 使用 cursor() 方法创建一个游标对象 cursor + cursor = db.cursor() + today = time.strftime("%Y-%m-%d", time.localtime()) + + # 使用 execute() 方法执行 SQL 查询 + cursor.execute( + '''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = "all" and platform="all" and create_at <= %s''', + (today), + ) + + collect_pay = {} - # 使用 execute() 方法执行 SQL 查询 - cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = %s and country = "all" and platform="all" and create_at >= "2021-02-23" and create_at <= %s''',(region , today)) - collect_pay = [] - - for row in cursor.fetchall(): - # print(row) - coin, extra_coins, pay_users, create_at = row - rowlist = [coin, extra_coins, pay_users, create_at.hour, create_at] - # print(dir(create_at), create_at.hour) - collect_pay.append(rowlist) - # d = str(create_at.date()) - # if d in collect_pay: - # collect_pay.append(row) - # else: - # collect_pay[d] = [ row ] - # print(dir(create_at), create_at.timestamp(), create_at.date()) - print('共查找出', cursor.rowcount, '条数据') + for row in cursor.fetchall(): + # print(row) + coin, extra_coins, pay_users, create_at = row + d = str(create_at.date()) + if d in collect_pay: + collect_pay[d].append(row) + else: + collect_pay[d] = [ row ] + # print(dir(create_at), create_at.timestamp(), create_at.date()) + print('共查找出', cursor.rowcount, '条数据') + deletelist = [] + for k in collect_pay: + if len(collect_pay[k]) != 24: + deletelist.append(k) + for k in deletelist: + del collect_pay[k] + querydate= [] + for k in collect_pay: + querydate.append(k) + + querydate.sort() + cursor.execute( + '''SELECT coin, users, create_at from gift_items_hour pih where region = "all" and create_at >= %s and create_at <= %s''', + (querydate[0], querydate[-1]), + ) - # if cursor.rowcount <= 500: - # collect["pay-" + region] = None - # collect["gift-" + region] = None - # continue + collect_gift = {} + for row in cursor.fetchall(): + + coin, users, create_at = row + d = str(create_at.date()) + if d in collect_gift: + collect_gift[d].append(row) + else: + collect_gift[d] = [ row ] - # deletelist = [] - # for k in collect_pay: - # if len(collect_pay[k]) != 24: - # deletelist.append(k) + for k in collect_pay: + l = collect_pay[k] + l.sort(key=lambda x:x[3]) - # for k in deletelist: - # del collect_pay[k] + for k in collect_gift: + l = collect_gift[k] + l.sort(key=lambda x:x[2]) - # querydate= [] - # for k in collect_pay: - # querydate.append(k) - - - - # querydate.sort() - cursor.execute( - '''SELECT coin, users, create_at from gift_items_hour pih where region = %s and country = "all" and create_at >= "2021-02-23" and create_at <= %s''', - (region, today), - ) - - collect_gift = [] - for row in cursor.fetchall(): - - coin, users, create_at = row - rowlist = [coin, users, create_at.hour, create_at] - collect_gift.append(rowlist) - - # d = str(create_at.date()) - # if d in collect_gift: - # collect_gift[d].append(row) - # else: - # collect_gift[d] = [ row ] - - - collect_pay.sort(key=lambda x:x[-1]) - collect_gift.sort(key=lambda x:x[-1]) - # for k in collect_gift: - # l = collect_gift[k] - # l.sort(key=lambda x:x[2]) - yesterday = {} - for v in collect_pay: - print(v[-1]) - date = (v[-1].date() - datetime.timedelta(days=1)).__str__() - print(date) - if date not in yesterday: - cursor.execute( - '''SELECT coin, extra_coins, pay_users, create_at from pay_items_day pid where region = %s and country = "all" and platform="all" and create_at = %s''', - (region , date), - ) - row = cursor.fetchone() - coin, extra_coins, pay_users, create_at = row - yesterday[date] = coin + extra_coins - v.insert(-2, yesterday[date]) - - yesterday = {} - for v in collect_gift: - print(v[-1]) - date = (v[-1].date() - datetime.timedelta(days=1)).__str__() - print(date) - if date not in yesterday: - cursor.execute( - '''SELECT coin, users, create_at from gift_items_day where region = %s and country = "all" and create_at = %s''', - (region , date), - ) - row = cursor.fetchone() - coin, users, create_at = row - yesterday[date] = coin - v.insert(-2, yesterday[date]) - - collect["pay-" + region] = collect_pay - collect["gift-" + region] = collect_gift - except Exception as e: - # print(e) - logging.error(traceback.format_exc()) - + collect["pay"] = collect_pay + collect["gift"] = collect_gift pickle.dump(collect, open(loadfile, 'wb+')) - finally: return collect -def load_pay_data(textNum = 80, region = "all"): +def load_pay_data(textNum = 80): collect = get_collect() @@ -168,13 +98,11 @@ def load_pay_data(textNum = 80, region = "all"): x_train = [] y_train = [] - rkey = "pay-" + region - collect_pay = collect[rkey] - - # for k in collect[rkey]: - # collect_pay.append(collect[rkey][k]) + collect_pay = [] + for k in collect["pay"]: + collect_pay.append(collect["pay"][k]) - # collect_pay.sort(key=lambda x:x[0][3]) + collect_pay.sort(key=lambda x:x[0][3]) lastday_v = collect_pay[0] for cur_v in collect_pay[1:]: @@ -182,6 +110,9 @@ def load_pay_data(textNum = 80, region = "all"): users = 0 last_total_coin = 0 + for v2 in lastday_v: + last_total_coin += v2[0] + v2[1] + count = 0 for v1, v2 in zip(cur_v,lastday_v): total_coin += v1[0] + v1[1] @@ -194,7 +125,7 @@ def load_pay_data(textNum = 80, region = "all"): # print(compare) # 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数 - x_train.append([v1[-2] ,total_coin / v1[-3] , total_coin]) + x_train.append([count ,total_coin / last_total_coin , total_coin]) count+=1 for i in range(count): @@ -210,23 +141,21 @@ def load_pay_data(textNum = 80, region = "all"): tx_train = x_train[len(x_train) - textNum:] ty_train = y_train[len(y_train) - textNum:] - # x_train = x_train[:len(x_train) - textNum] - # y_train = y_train[:len(y_train) - textNum] + x_train = x_train[:len(x_train) - textNum] + y_train = y_train[:len(y_train) - textNum] return x_train, y_train, tx_train, ty_train, input_shape -def load_gift_data(textNum = 80, region = "all"): +def load_gift_data(textNum = 80): collect = get_collect() x_train = [] y_train = [] - - rkey = "gift-" + region collect_gift = [] - for k in collect[rkey]: - collect_gift.append(collect[rkey][k]) + for k in collect["gift"]: + collect_gift.append(collect["gift"][k]) collect_gift.sort(key=lambda x:x[0][2]) lastday_v = collect_gift[0] @@ -268,7 +197,7 @@ def load_gift_data(textNum = 80, region = "all"): tx_train = x_train[len(x_train) - textNum:] ty_train = y_train[len(y_train) - textNum:] - # x_train = x_train[:len(x_train) - textNum] - # y_train = y_train[:len(y_train) - textNum] + x_train = x_train[:len(x_train) - textNum] + y_train = y_train[:len(y_train) - textNum] return x_train, y_train, tx_train, ty_train, input_shape \ No newline at end of file diff --git a/predict.py b/predict.py deleted file mode 100644 index 1fe3804..0000000 --- a/predict.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy -from keras.models import load_model -from data import load_pay_data, load_gift_data - -import matplotlib.pyplot as plt - - -# x_train, y_train, tx_train, ty_train, _ = load_pay_data(160) -# model = load_model("./predict_pay") - -# p_data = model.predict(tx_train) -# for i in range(len(p_data)): -# comp = (p_data[i][0] - ty_train[i]) / ty_train[i] -# print(comp, p_data[i][0], ty_train[i]) -# if abs(comp) >= 1: -# print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) - - -x_train, y_train, tx_train, ty_train, _ = load_gift_data(160) -model = load_model("./predict_gift") -p_data = model.predict(tx_train) -for i in range(len(p_data)): - comp = (p_data[i][0] - ty_train[i]) / ty_train[i] - print(comp, p_data[i][0], ty_train[i]) - if abs(comp) >= 0.1: - print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) - - -plt.plot(ty_train) -plt.plot(p_data) -plt.show() -# data = numpy.reshape([[15, 2359688 / 10000000, 255968 / 1000000, 10 / 10000]],(1, 4, 1)) -# print( model.predict(data)) - - - - \ No newline at end of file diff --git a/predict_gift.py b/predict_gift.py new file mode 100644 index 0000000..b1dfbde --- /dev/null +++ b/predict_gift.py @@ -0,0 +1,25 @@ +import numpy +from keras.models import load_model +from data import load_pay_data, load_gift_data + +import matplotlib.pyplot as plt + + +x_train, y_train, tx_train, ty_train, _ = load_gift_data(160) +model = load_model("./predict_gift") +p_data = model.predict(tx_train) +for i in range(len(p_data)): + comp = (p_data[i][0] - ty_train[i]) / ty_train[i] + print(comp, p_data[i][0], ty_train[i]) + if abs(comp) >= 0.1: + print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) + + +plt.plot(ty_train) +plt.plot(p_data) +plt.show() + + + + + \ No newline at end of file diff --git a/predict_pay.py b/predict_pay.py new file mode 100644 index 0000000..0a82827 --- /dev/null +++ b/predict_pay.py @@ -0,0 +1,20 @@ +import numpy +from keras.models import load_model +from data import load_pay_data, load_gift_data + +import matplotlib.pyplot as plt + + +x_train, y_train, tx_train, ty_train, _ = load_pay_data(160) +model = load_model("./predict_pay") + +p_data = model.predict(tx_train) +for i in range(len(p_data)): + comp = (p_data[i][0] - ty_train[i]) / ty_train[i] + print(comp, p_data[i][0], ty_train[i]) + if abs(comp) >= 1: + print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) + +plt.plot(ty_train) +plt.plot(p_data) +plt.show() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 886ea59..bb144e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ tensorflow keras numpy pymysql -grpc_tools \ No newline at end of file +grpcio +grpc_tools \ No newline at end of file diff --git a/train_pay.py b/train_pay.py index 335bb12..190fb59 100644 --- a/train_pay.py +++ b/train_pay.py @@ -4,7 +4,8 @@ from keras.models import Sequential from keras.layers import Dense, Dropout, Embedding from keras.layers import InputLayer from keras.layers import LSTM -from keras import backend +from keras import backend as K +from keras.losses import mean_squared_error from keras.layers.recurrent import SimpleRNN import pymysql @@ -14,11 +15,12 @@ import numpy from data import load_pay_data +def mean_squared_error(y_true, y_pred): + return K.mean(K.square(y_pred - y_true), axis=-1) if __name__ == "__main__": - region = "Region_Arab" - x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80, region) + x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80) model = Sequential() @@ -29,10 +31,10 @@ if __name__ == "__main__": # model.add(Dropout(0.1)) model.add(Dense(1)) model.summary() - model.compile(loss = 'mse', optimizer = 'adam') + model.compile(loss = 'msle', optimizer = 'adam') - model.fit(x_train, y_train, batch_size=96, epochs=1200) - model.save("./predict_pay_" + region) + model.fit(x_train, y_train, batch_size=96, epochs=600) + model.save("./predict_pay") p_data = model.predict(tx_train) for i in range(len(p_data)):