diff --git a/data.py b/data.py index 5771885..0b99f48 100644 --- a/data.py +++ b/data.py @@ -1,3 +1,6 @@ +import logging +import traceback + from keras.models import Sequential from keras.layers import Dense, Dropout, Embedding from keras.layers import InputLayer @@ -10,7 +13,25 @@ import os import numpy import time, datetime -def get_collect(region = "all"): +regions = [ + "all", + "Region_Arab", + "Region_China", + "Region_English", + "Region_Germany", + "Region_India", + "Region_Indonesia", + "Region_Japan", + "Region_Philippines", + "Region_Portuguese", + "Region_Russian", + "Region_Spanish", + "Region_Thailand", + "Region_Turkey", + "Region_Vietnam", +] + +def get_collect(): collect = {} loadfile = "./collect.pickle" @@ -18,66 +39,120 @@ def get_collect(region = "all"): collect = pickle.load(open(loadfile, 'rb')) except Exception as e: print(e) + try: + # 打开数据库连接 - db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8") - - # 使用 cursor() 方法创建一个游标对象 cursor - cursor = db.cursor() - today = time.strftime("%Y-%m-%d", time.localtime()) - - # 使用 execute() 方法执行 SQL 查询 - cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = "%s" and platform="all" and create_at <= %s''',region , today) - collect_pay = {} - for row in cursor.fetchall(): - # print(row) - coin, extra_coins, pay_users, create_at = row - d = str(create_at.date()) - if d in collect_pay: - collect_pay[d].append(row) - else: - collect_pay[d] = [ row ] - # print(dir(create_at), create_at.timestamp(), create_at.date()) - print('共查找出', cursor.rowcount, '条数据') - deletelist = [] - for k in collect_pay: - if len(collect_pay[k]) != 24: - deletelist.append(k) + db = pymysql.connect(host="sg-board1.livenono.com", port=3306,user="root",passwd="Nono-databoard",db="databoard",charset="utf8") + + # 使用 cursor() 方法创建一个游标对象 cursor + cursor = db.cursor() + today = time.strftime("%Y-%m-%d", time.localtime()) + + for region in regions: - for k in deletelist: - del collect_pay[k] + # 使用 execute() 方法执行 SQL 查询 + cursor.execute('''SELECT coin, extra_coins, pay_users, create_at from pay_items_hour pih where region = %s and country = "all" and platform="all" and create_at >= "2021-02-23" and create_at <= %s''',(region , today)) + collect_pay = [] + + for row in cursor.fetchall(): + # print(row) + coin, extra_coins, pay_users, create_at = row + rowlist = [coin, extra_coins, pay_users, create_at.hour, create_at] + # print(dir(create_at), create_at.hour) + collect_pay.append(rowlist) + # d = str(create_at.date()) + # if d in collect_pay: + # collect_pay.append(row) + # else: + # collect_pay[d] = [ row ] + # print(dir(create_at), create_at.timestamp(), create_at.date()) + print('共查找出', cursor.rowcount, '条数据') - querydate= [] - for k in collect_pay: - querydate.append(k) - - querydate.sort() - cursor.execute( - '''SELECT coin, users, create_at from gift_items_hour pih where region = "all" and create_at >= %s and create_at <= %s''', - (querydate[0], querydate[-1]), - ) - collect_gift = {} - for row in cursor.fetchall(): - - coin, users, create_at = row - d = str(create_at.date()) - if d in collect_gift: - collect_gift[d].append(row) - else: - collect_gift[d] = [ row ] - for k in collect_pay: - l = collect_pay[k] - l.sort(key=lambda x:x[3]) + # if cursor.rowcount <= 500: + # collect["pay-" + region] = None + # collect["gift-" + region] = None + # continue - for k in collect_gift: - l = collect_gift[k] - l.sort(key=lambda x:x[2]) + # deletelist = [] + # for k in collect_pay: + # if len(collect_pay[k]) != 24: + # deletelist.append(k) - collect["pay"] = collect_pay - collect["gift"] = collect_gift + # for k in deletelist: + # del collect_pay[k] + + # querydate= [] + # for k in collect_pay: + # querydate.append(k) + + + + # querydate.sort() + cursor.execute( + '''SELECT coin, users, create_at from gift_items_hour pih where region = %s and country = "all" and create_at >= "2021-02-23" and create_at <= %s''', + (region, today), + ) + + collect_gift = [] + for row in cursor.fetchall(): + + coin, users, create_at = row + rowlist = [coin, users, create_at.hour, create_at] + collect_gift.append(rowlist) + + # d = str(create_at.date()) + # if d in collect_gift: + # collect_gift[d].append(row) + # else: + # collect_gift[d] = [ row ] + + + collect_pay.sort(key=lambda x:x[-1]) + collect_gift.sort(key=lambda x:x[-1]) + # for k in collect_gift: + # l = collect_gift[k] + # l.sort(key=lambda x:x[2]) + yesterday = {} + for v in collect_pay: + print(v[-1]) + date = (v[-1].date() - datetime.timedelta(days=1)).__str__() + print(date) + if date not in yesterday: + cursor.execute( + '''SELECT coin, extra_coins, pay_users, create_at from pay_items_day pid where region = %s and country = "all" and platform="all" and create_at = %s''', + (region , date), + ) + row = cursor.fetchone() + coin, extra_coins, pay_users, create_at = row + yesterday[date] = coin + extra_coins + v.insert(-2, yesterday[date]) + + yesterday = {} + for v in collect_gift: + print(v[-1]) + date = (v[-1].date() - datetime.timedelta(days=1)).__str__() + print(date) + if date not in yesterday: + cursor.execute( + '''SELECT coin, users, create_at from gift_items_day where region = %s and country = "all" and create_at = %s''', + (region , date), + ) + row = cursor.fetchone() + coin, users, create_at = row + yesterday[date] = coin + v.insert(-2, yesterday[date]) + + collect["pay-" + region] = collect_pay + collect["gift-" + region] = collect_gift + except Exception as e: + # print(e) + logging.error(traceback.format_exc()) + pickle.dump(collect, open(loadfile, 'wb+')) + finally: return collect @@ -93,11 +168,13 @@ def load_pay_data(textNum = 80, region = "all"): x_train = [] y_train = [] - collect_pay = [] - for k in collect["pay"]: - collect_pay.append(collect["pay"][k]) + rkey = "pay-" + region + collect_pay = collect[rkey] + + # for k in collect[rkey]: + # collect_pay.append(collect[rkey][k]) - collect_pay.sort(key=lambda x:x[0][3]) + # collect_pay.sort(key=lambda x:x[0][3]) lastday_v = collect_pay[0] for cur_v in collect_pay[1:]: @@ -105,9 +182,6 @@ def load_pay_data(textNum = 80, region = "all"): users = 0 last_total_coin = 0 - for v2 in lastday_v: - last_total_coin += v2[0] + v2[1] - count = 0 for v1, v2 in zip(cur_v,lastday_v): total_coin += v1[0] + v1[1] @@ -120,7 +194,7 @@ def load_pay_data(textNum = 80, region = "all"): # print(compare) # 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数 - x_train.append([count ,total_coin / last_total_coin , total_coin]) + x_train.append([v1[-2] ,total_coin / v1[-3] , total_coin]) count+=1 for i in range(count): @@ -141,16 +215,18 @@ def load_pay_data(textNum = 80, region = "all"): return x_train, y_train, tx_train, ty_train, input_shape -def load_gift_data(textNum = 80): +def load_gift_data(textNum = 80, region = "all"): collect = get_collect() x_train = [] y_train = [] + + rkey = "gift-" + region collect_gift = [] - for k in collect["gift"]: - collect_gift.append(collect["gift"][k]) + for k in collect[rkey]: + collect_gift.append(collect[rkey][k]) collect_gift.sort(key=lambda x:x[0][2]) lastday_v = collect_gift[0] diff --git a/train_pay.py b/train_pay.py index 1c205c2..335bb12 100644 --- a/train_pay.py +++ b/train_pay.py @@ -17,7 +17,8 @@ from data import load_pay_data if __name__ == "__main__": - x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80) + region = "Region_Arab" + x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80, region) model = Sequential() @@ -31,7 +32,7 @@ if __name__ == "__main__": model.compile(loss = 'mse', optimizer = 'adam') model.fit(x_train, y_train, batch_size=96, epochs=1200) - model.save("./predict_pay") + model.save("./predict_pay_" + region) p_data = model.predict(tx_train) for i in range(len(p_data)):