diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..975bc9d --- /dev/null +++ b/__init__.py @@ -0,0 +1,4 @@ +import sys +import os + +sys.path.append(".") diff --git a/api/predict.proto b/api/predict.proto new file mode 100644 index 0000000..3e52081 --- /dev/null +++ b/api/predict.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +service Predict { + rpc PayDay(RequestPay) returns (ReplyPay) {} + rpc GiftDay(RequestGift) returns (ReplyGift) {} +} + +// Request 统一Request请求 +message RequestPay { + // count, total_coin / last_total_coin, total_coin // 24小时 + int32 Hour = 1; + int64 Coin = 2; + int64 YesterdayCoin = 3; +} + +// RequestGift 统一Request请求 +message RequestGift { + // count, total_coin / last_total_coin, total_coin // 24小时 + int32 Hour = 1; + int64 Coin = 2; + int64 YesterdayCoin = 3; +} + +// ReplyHeader 统一ReplyHeader 响应 +message ReplyHeader { + int32 Code = 1; + string Message = 2; +} + +message ReplyPay { + ReplyHeader Header = 1; + int64 Result = 2; +} + + +message ReplyGift { + ReplyHeader Header = 1; + int64 Result = 2; +} \ No newline at end of file diff --git a/data.py b/data.py index ab2681a..a7d4dcc 100644 --- a/data.py +++ b/data.py @@ -81,9 +81,10 @@ def get_collect(): finally: return collect - - + + def load_pay_data(textNum = 80): + collect = get_collect() @@ -101,6 +102,7 @@ def load_pay_data(textNum = 80): for cur_v in collect_pay[1:]: total_coin = 0 + users = 0 last_total_coin = 0 for v2 in lastday_v: @@ -109,14 +111,16 @@ def load_pay_data(textNum = 80): count = 0 for v1, v2 in zip(cur_v,lastday_v): total_coin += v1[0] + v1[1] + users += v1[2] # print(v1[3]) # last_total_coin += v2[0] + v2[1] # print(v2[3]) # compare = float(total_coin - last_total_coin) / float(last_total_coin) # print(compare) - - x_train.append([count, total_coin, total_coin/last_total_coin]) + + # 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数 + x_train.append([count ,total_coin / last_total_coin , total_coin]) count+=1 for i in range(count): @@ -124,7 +128,8 @@ def load_pay_data(textNum = 80): lastday_v = cur_v - x_train = numpy.reshape(x_train, (len(x_train) , 3, 1)) + input_shape = (len(x_train[0]), 1) + x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1])) y_train = numpy.reshape(y_train, (len(y_train))) # max_features = 1024 @@ -134,7 +139,7 @@ def load_pay_data(textNum = 80): # x_train = x_train[:len(x_train) - textNum] # y_train = y_train[:len(y_train) - textNum] - return x_train, y_train, tx_train, ty_train + return x_train, y_train, tx_train, ty_train, input_shape def load_gift_data(textNum = 80): @@ -159,7 +164,7 @@ def load_gift_data(textNum = 80): last_total_coin += v2[0] f = 20000000.0 - count = 1 + count = 0 for v1, v2 in zip(cur_v,lastday_v): total_coin += v1[0] # print(v1[3]) @@ -170,16 +175,17 @@ def load_gift_data(textNum = 80): # print(v2[3]) # compare = float(total_coin - last_total_coin) / float(last_total_coin) # print(compare) - - x_train.append([count, total_coin, total_coin / last_total_coin, users ]) + # 参数 前一小个小时. 时刻. 当前金钱. 送礼人数 + x_train.append([count, total_coin / last_total_coin, total_coin ]) count+=1 - for i in range(count - 1): + for i in range(count): y_train.append(total_coin) lastday_v = cur_v - x_train = numpy.reshape(x_train, (len(x_train) , 4, 1)) + input_shape = (len(x_train[0]), 1) + x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1])) y_train = numpy.reshape(y_train, (len(y_train))) # max_features = 1024 @@ -189,4 +195,4 @@ def load_gift_data(textNum = 80): # x_train = x_train[:len(x_train) - textNum] # y_train = y_train[:len(y_train) - textNum] - return x_train, y_train, tx_train, ty_train \ No newline at end of file + return x_train, y_train, tx_train, ty_train, input_shape \ No newline at end of file diff --git a/example.py b/example.py deleted file mode 100644 index 0519ecb..0000000 --- a/example.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/gen_proto3.sh b/gen_proto3.sh new file mode 100644 index 0000000..c92efaa --- /dev/null +++ b/gen_proto3.sh @@ -0,0 +1,3 @@ +#! /bin/bash +PBPATH=./api +python -m grpc_tools.protoc -I$PBPATH --python_out=. --grpc_python_out=. $PBPATH/*.proto diff --git a/grpc_client.py b/grpc_client.py new file mode 100644 index 0000000..53f4d26 --- /dev/null +++ b/grpc_client.py @@ -0,0 +1,20 @@ +import grpc +import logging +import predict_pb2, predict_pb2_grpc + + +def run(): + option = [('grpc.keepalive_timeout_ms', 10000)] + with grpc.insecure_channel(target='localhost:50051', options=option) as channel: + stub = predict_pb2_grpc.PredictStub(channel) + + response = stub.PayDay( predict_pb2.RequestPay(Hour=0,Coin=100,YesterdayCoin=100), timeout=10) + print(response, type(response.Header.Code)) + + response = stub.GiftDay(predict_pb2.RequestGift(Hour=0,Coin=100,YesterdayCoin=100), timeout=10) + print(response, type(response.Header.Code)) + + +if __name__ == '__main__': + logging.basicConfig() + run() \ No newline at end of file diff --git a/grpc_server.py b/grpc_server.py new file mode 100644 index 0000000..0c9efdf --- /dev/null +++ b/grpc_server.py @@ -0,0 +1,39 @@ + + +import grpc +import logging +from concurrent import futures + +import predict_pb2, predict_pb2_grpc +from predict_pb2 import * +from predict_pb2_grpc import * + +import numpy +from keras.models import load_model +from data import load_pay_data, load_gift_data +pay_model = load_model("./predict_pay") +gift_model = load_model("./predict_gift") + +class Predict(predict_pb2_grpc.PredictServicer): + def PayDay(self, request, context): + # print(request) + inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1)) + predict_value = pay_model.predict(inputx) + return ReplyPay(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0])) + # return super().PayDay(request, context) + def GiftDay(self, request, context): + # print(request) + inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1)) + predict_value = gift_model.predict(inputx) + return ReplyGift(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0])) + +def server(): + grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) + predict_pb2_grpc.add_PredictServicer_to_server(Predict(), grpc_server) + grpc_server.add_insecure_port('[::]:50051') + grpc_server.start() + grpc_server.wait_for_termination() + +if __name__ == '__main__': + logging.basicConfig() + server() \ No newline at end of file diff --git a/predict.proto b/predict.proto deleted file mode 100644 index 4158aaa..0000000 --- a/predict.proto +++ /dev/null @@ -1,24 +0,0 @@ -syntax = "proto3"; - -service Predict { - rpc PayDay(RequestPay) returns (Reply) {} - rpc GiftDay(RequestGift) returns (Reply) {} -} - -// Request 统一Request请求 -message RequestPay { - // count, total_coin , (total_coin - last_total_coin) , v1[2] , v2[2]] // 24小时 - int32 Hour = 1; - int64 Coin = 2; - int64 YesterdayCoin = 3; -} - -// RequestGift 统一Request请求 -message RequestGift { - -} - -// Reply 统一Reply 响应 -message Reply { - -} \ No newline at end of file diff --git a/predict.py b/predict.py index d910cb1..1fe3804 100644 --- a/predict.py +++ b/predict.py @@ -5,7 +5,7 @@ from data import load_pay_data, load_gift_data import matplotlib.pyplot as plt -# x_train, y_train, tx_train, ty_train = load_pay_data(160) +# x_train, y_train, tx_train, ty_train, _ = load_pay_data(160) # model = load_model("./predict_pay") # p_data = model.predict(tx_train) @@ -16,7 +16,7 @@ import matplotlib.pyplot as plt # print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) -x_train, y_train, tx_train, ty_train = load_gift_data(160) +x_train, y_train, tx_train, ty_train, _ = load_gift_data(160) model = load_model("./predict_gift") p_data = model.predict(tx_train) for i in range(len(p_data)): diff --git a/predict_pb2.py b/predict_pb2.py new file mode 100644 index 0000000..304bb70 --- /dev/null +++ b/predict_pb2.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: predict.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='predict.proto', + package='', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\rpredict.proto\"?\n\nRequestPay\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\"@\n\x0bRequestGift\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\",\n\x0bReplyHeader\x12\x0c\n\x04\x43ode\x18\x01 \x01(\x05\x12\x0f\n\x07Message\x18\x02 \x01(\t\"8\n\x08ReplyPay\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\"9\n\tReplyGift\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\x32T\n\x07Predict\x12\"\n\x06PayDay\x12\x0b.RequestPay\x1a\t.ReplyPay\"\x00\x12%\n\x07GiftDay\x12\x0c.RequestGift\x1a\n.ReplyGift\"\x00\x62\x06proto3' +) + + + + +_REQUESTPAY = _descriptor.Descriptor( + name='RequestPay', + full_name='RequestPay', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='Hour', full_name='RequestPay.Hour', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='Coin', full_name='RequestPay.Coin', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='YesterdayCoin', full_name='RequestPay.YesterdayCoin', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=17, + serialized_end=80, +) + + +_REQUESTGIFT = _descriptor.Descriptor( + name='RequestGift', + full_name='RequestGift', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='Hour', full_name='RequestGift.Hour', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='Coin', full_name='RequestGift.Coin', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='YesterdayCoin', full_name='RequestGift.YesterdayCoin', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=82, + serialized_end=146, +) + + +_REPLYHEADER = _descriptor.Descriptor( + name='ReplyHeader', + full_name='ReplyHeader', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='Code', full_name='ReplyHeader.Code', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='Message', full_name='ReplyHeader.Message', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=148, + serialized_end=192, +) + + +_REPLYPAY = _descriptor.Descriptor( + name='ReplyPay', + full_name='ReplyPay', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='Header', full_name='ReplyPay.Header', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='Result', full_name='ReplyPay.Result', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=194, + serialized_end=250, +) + + +_REPLYGIFT = _descriptor.Descriptor( + name='ReplyGift', + full_name='ReplyGift', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='Header', full_name='ReplyGift.Header', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='Result', full_name='ReplyGift.Result', index=1, + number=2, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=252, + serialized_end=309, +) + +_REPLYPAY.fields_by_name['Header'].message_type = _REPLYHEADER +_REPLYGIFT.fields_by_name['Header'].message_type = _REPLYHEADER +DESCRIPTOR.message_types_by_name['RequestPay'] = _REQUESTPAY +DESCRIPTOR.message_types_by_name['RequestGift'] = _REQUESTGIFT +DESCRIPTOR.message_types_by_name['ReplyHeader'] = _REPLYHEADER +DESCRIPTOR.message_types_by_name['ReplyPay'] = _REPLYPAY +DESCRIPTOR.message_types_by_name['ReplyGift'] = _REPLYGIFT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +RequestPay = _reflection.GeneratedProtocolMessageType('RequestPay', (_message.Message,), { + 'DESCRIPTOR' : _REQUESTPAY, + '__module__' : 'predict_pb2' + # @@protoc_insertion_point(class_scope:RequestPay) + }) +_sym_db.RegisterMessage(RequestPay) + +RequestGift = _reflection.GeneratedProtocolMessageType('RequestGift', (_message.Message,), { + 'DESCRIPTOR' : _REQUESTGIFT, + '__module__' : 'predict_pb2' + # @@protoc_insertion_point(class_scope:RequestGift) + }) +_sym_db.RegisterMessage(RequestGift) + +ReplyHeader = _reflection.GeneratedProtocolMessageType('ReplyHeader', (_message.Message,), { + 'DESCRIPTOR' : _REPLYHEADER, + '__module__' : 'predict_pb2' + # @@protoc_insertion_point(class_scope:ReplyHeader) + }) +_sym_db.RegisterMessage(ReplyHeader) + +ReplyPay = _reflection.GeneratedProtocolMessageType('ReplyPay', (_message.Message,), { + 'DESCRIPTOR' : _REPLYPAY, + '__module__' : 'predict_pb2' + # @@protoc_insertion_point(class_scope:ReplyPay) + }) +_sym_db.RegisterMessage(ReplyPay) + +ReplyGift = _reflection.GeneratedProtocolMessageType('ReplyGift', (_message.Message,), { + 'DESCRIPTOR' : _REPLYGIFT, + '__module__' : 'predict_pb2' + # @@protoc_insertion_point(class_scope:ReplyGift) + }) +_sym_db.RegisterMessage(ReplyGift) + + + +_PREDICT = _descriptor.ServiceDescriptor( + name='Predict', + full_name='Predict', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=311, + serialized_end=395, + methods=[ + _descriptor.MethodDescriptor( + name='PayDay', + full_name='Predict.PayDay', + index=0, + containing_service=None, + input_type=_REQUESTPAY, + output_type=_REPLYPAY, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + _descriptor.MethodDescriptor( + name='GiftDay', + full_name='Predict.GiftDay', + index=1, + containing_service=None, + input_type=_REQUESTGIFT, + output_type=_REPLYGIFT, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), +]) +_sym_db.RegisterServiceDescriptor(_PREDICT) + +DESCRIPTOR.services_by_name['Predict'] = _PREDICT + +# @@protoc_insertion_point(module_scope) diff --git a/predict_pb2_grpc.py b/predict_pb2_grpc.py new file mode 100644 index 0000000..78b837b --- /dev/null +++ b/predict_pb2_grpc.py @@ -0,0 +1,99 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import predict_pb2 as predict__pb2 + + +class PredictStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.PayDay = channel.unary_unary( + '/Predict/PayDay', + request_serializer=predict__pb2.RequestPay.SerializeToString, + response_deserializer=predict__pb2.ReplyPay.FromString, + ) + self.GiftDay = channel.unary_unary( + '/Predict/GiftDay', + request_serializer=predict__pb2.RequestGift.SerializeToString, + response_deserializer=predict__pb2.ReplyGift.FromString, + ) + + +class PredictServicer(object): + """Missing associated documentation comment in .proto file.""" + + def PayDay(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GiftDay(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_PredictServicer_to_server(servicer, server): + rpc_method_handlers = { + 'PayDay': grpc.unary_unary_rpc_method_handler( + servicer.PayDay, + request_deserializer=predict__pb2.RequestPay.FromString, + response_serializer=predict__pb2.ReplyPay.SerializeToString, + ), + 'GiftDay': grpc.unary_unary_rpc_method_handler( + servicer.GiftDay, + request_deserializer=predict__pb2.RequestGift.FromString, + response_serializer=predict__pb2.ReplyGift.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'Predict', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Predict(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def PayDay(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Predict/PayDay', + predict__pb2.RequestPay.SerializeToString, + predict__pb2.ReplyPay.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GiftDay(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/Predict/GiftDay', + predict__pb2.RequestGift.SerializeToString, + predict__pb2.ReplyGift.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/requirements.txt b/requirements.txt index 9a4d144..886ea59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ tensorflow keras numpy pymysql +grpc_tools \ No newline at end of file diff --git a/train_gift.py b/train_gift.py index 21512bb..1cb085e 100644 --- a/train_gift.py +++ b/train_gift.py @@ -15,12 +15,12 @@ from data import load_gift_data if __name__ == "__main__": - x_train, y_train, tx_train, ty_train = load_gift_data() + x_train, y_train, tx_train, ty_train, input_shape = load_gift_data() model = Sequential() units = 400 - model.add(LSTM(units, activation='relu', input_shape=(4,1) )) + model.add(LSTM(units, activation='relu', input_shape=input_shape )) model.add(Dropout(0.2)) model.add(Dense(1)) @@ -28,7 +28,7 @@ if __name__ == "__main__": model.compile(loss='mse', optimizer='adam') - model.fit(x_train, y_train, batch_size=1, epochs=50) + model.fit(x_train, y_train, batch_size=128, epochs=1500) model.save("./predict_gift") p_data = model.predict(tx_train) diff --git a/train_pay.py b/train_pay.py index 93fa223..1c205c2 100644 --- a/train_pay.py +++ b/train_pay.py @@ -5,6 +5,7 @@ from keras.layers import Dense, Dropout, Embedding from keras.layers import InputLayer from keras.layers import LSTM from keras import backend +from keras.layers.recurrent import SimpleRNN import pymysql import pickle @@ -16,18 +17,20 @@ from data import load_pay_data if __name__ == "__main__": - x_train, y_train, tx_train, ty_train = load_pay_data(80) + x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80) model = Sequential() units = 500 - model.add(LSTM(units, activation='relu', input_shape=(3,1))) - model.add(Dropout(0.3)) + model.add(LSTM(units, activation='relu', dropout=0.1, input_shape=input_shape)) + + # model.add(SimpleRNN(units, activation='relu')) + # model.add(Dropout(0.1)) model.add(Dense(1)) model.summary() - model.compile(loss='mse', optimizer='adam') + model.compile(loss = 'mse', optimizer = 'adam') - model.fit(x_train, y_train, batch_size=1, epochs=50) + model.fit(x_train, y_train, batch_size=96, epochs=1200) model.save("./predict_pay") p_data = model.predict(tx_train)