完成预测模型的grpc服务

This commit is contained in:
eson 2021-03-25 17:36:09 +08:00
parent 36b30982f7
commit bb90138620
14 changed files with 553 additions and 47 deletions

4
__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import os
sys.path.append(".")

39
api/predict.proto Normal file
View File

@ -0,0 +1,39 @@
syntax = "proto3";
service Predict {
rpc PayDay(RequestPay) returns (ReplyPay) {}
rpc GiftDay(RequestGift) returns (ReplyGift) {}
}
// Request Request请求
message RequestPay {
// count, total_coin / last_total_coin, total_coin // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// RequestGift Request请求
message RequestGift {
// count, total_coin / last_total_coin, total_coin // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// ReplyHeader ReplyHeader
message ReplyHeader {
int32 Code = 1;
string Message = 2;
}
message ReplyPay {
ReplyHeader Header = 1;
int64 Result = 2;
}
message ReplyGift {
ReplyHeader Header = 1;
int64 Result = 2;
}

30
data.py
View File

@ -81,9 +81,10 @@ def get_collect():
finally: finally:
return collect return collect
def load_pay_data(textNum = 80): def load_pay_data(textNum = 80):
collect = get_collect() collect = get_collect()
@ -101,6 +102,7 @@ def load_pay_data(textNum = 80):
for cur_v in collect_pay[1:]: for cur_v in collect_pay[1:]:
total_coin = 0 total_coin = 0
users = 0
last_total_coin = 0 last_total_coin = 0
for v2 in lastday_v: for v2 in lastday_v:
@ -109,14 +111,16 @@ def load_pay_data(textNum = 80):
count = 0 count = 0
for v1, v2 in zip(cur_v,lastday_v): for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] + v1[1] total_coin += v1[0] + v1[1]
users += v1[2]
# print(v1[3]) # print(v1[3])
# last_total_coin += v2[0] + v2[1] # last_total_coin += v2[0] + v2[1]
# print(v2[3]) # print(v2[3])
# compare = float(total_coin - last_total_coin) / float(last_total_coin) # compare = float(total_coin - last_total_coin) / float(last_total_coin)
# print(compare) # print(compare)
x_train.append([count, total_coin, total_coin/last_total_coin]) # 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
x_train.append([count ,total_coin / last_total_coin , total_coin])
count+=1 count+=1
for i in range(count): for i in range(count):
@ -124,7 +128,8 @@ def load_pay_data(textNum = 80):
lastday_v = cur_v lastday_v = cur_v
x_train = numpy.reshape(x_train, (len(x_train) , 3, 1)) input_shape = (len(x_train[0]), 1)
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
y_train = numpy.reshape(y_train, (len(y_train))) y_train = numpy.reshape(y_train, (len(y_train)))
# max_features = 1024 # max_features = 1024
@ -134,7 +139,7 @@ def load_pay_data(textNum = 80):
# x_train = x_train[:len(x_train) - textNum] # x_train = x_train[:len(x_train) - textNum]
# y_train = y_train[:len(y_train) - textNum] # y_train = y_train[:len(y_train) - textNum]
return x_train, y_train, tx_train, ty_train return x_train, y_train, tx_train, ty_train, input_shape
def load_gift_data(textNum = 80): def load_gift_data(textNum = 80):
@ -159,7 +164,7 @@ def load_gift_data(textNum = 80):
last_total_coin += v2[0] last_total_coin += v2[0]
f = 20000000.0 f = 20000000.0
count = 1 count = 0
for v1, v2 in zip(cur_v,lastday_v): for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] total_coin += v1[0]
# print(v1[3]) # print(v1[3])
@ -170,16 +175,17 @@ def load_gift_data(textNum = 80):
# print(v2[3]) # print(v2[3])
# compare = float(total_coin - last_total_coin) / float(last_total_coin) # compare = float(total_coin - last_total_coin) / float(last_total_coin)
# print(compare) # print(compare)
# 参数 前一小个小时. 时刻. 当前金钱. 送礼人数
x_train.append([count, total_coin, total_coin / last_total_coin, users ]) x_train.append([count, total_coin / last_total_coin, total_coin ])
count+=1 count+=1
for i in range(count - 1): for i in range(count):
y_train.append(total_coin) y_train.append(total_coin)
lastday_v = cur_v lastday_v = cur_v
x_train = numpy.reshape(x_train, (len(x_train) , 4, 1)) input_shape = (len(x_train[0]), 1)
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
y_train = numpy.reshape(y_train, (len(y_train))) y_train = numpy.reshape(y_train, (len(y_train)))
# max_features = 1024 # max_features = 1024
@ -189,4 +195,4 @@ def load_gift_data(textNum = 80):
# x_train = x_train[:len(x_train) - textNum] # x_train = x_train[:len(x_train) - textNum]
# y_train = y_train[:len(y_train) - textNum] # y_train = y_train[:len(y_train) - textNum]
return x_train, y_train, tx_train, ty_train return x_train, y_train, tx_train, ty_train, input_shape

View File

@ -1 +0,0 @@

3
gen_proto3.sh Normal file
View File

@ -0,0 +1,3 @@
#! /bin/bash
PBPATH=./api
python -m grpc_tools.protoc -I$PBPATH --python_out=. --grpc_python_out=. $PBPATH/*.proto

20
grpc_client.py Normal file
View File

@ -0,0 +1,20 @@
import grpc
import logging
import predict_pb2, predict_pb2_grpc
def run():
option = [('grpc.keepalive_timeout_ms', 10000)]
with grpc.insecure_channel(target='localhost:50051', options=option) as channel:
stub = predict_pb2_grpc.PredictStub(channel)
response = stub.PayDay( predict_pb2.RequestPay(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
print(response, type(response.Header.Code))
response = stub.GiftDay(predict_pb2.RequestGift(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
print(response, type(response.Header.Code))
if __name__ == '__main__':
logging.basicConfig()
run()

39
grpc_server.py Normal file
View File

@ -0,0 +1,39 @@
import grpc
import logging
from concurrent import futures
import predict_pb2, predict_pb2_grpc
from predict_pb2 import *
from predict_pb2_grpc import *
import numpy
from keras.models import load_model
from data import load_pay_data, load_gift_data
pay_model = load_model("./predict_pay")
gift_model = load_model("./predict_gift")
class Predict(predict_pb2_grpc.PredictServicer):
def PayDay(self, request, context):
# print(request)
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
predict_value = pay_model.predict(inputx)
return ReplyPay(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
# return super().PayDay(request, context)
def GiftDay(self, request, context):
# print(request)
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
predict_value = gift_model.predict(inputx)
return ReplyGift(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
def server():
grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
predict_pb2_grpc.add_PredictServicer_to_server(Predict(), grpc_server)
grpc_server.add_insecure_port('[::]:50051')
grpc_server.start()
grpc_server.wait_for_termination()
if __name__ == '__main__':
logging.basicConfig()
server()

View File

@ -1,24 +0,0 @@
syntax = "proto3";
service Predict {
rpc PayDay(RequestPay) returns (Reply) {}
rpc GiftDay(RequestGift) returns (Reply) {}
}
// Request Request请求
message RequestPay {
// count, total_coin , (total_coin - last_total_coin) , v1[2] , v2[2]] // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// RequestGift Request请求
message RequestGift {
}
// Reply Reply
message Reply {
}

View File

@ -5,7 +5,7 @@ from data import load_pay_data, load_gift_data
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# x_train, y_train, tx_train, ty_train = load_pay_data(160) # x_train, y_train, tx_train, ty_train, _ = load_pay_data(160)
# model = load_model("./predict_pay") # model = load_model("./predict_pay")
# p_data = model.predict(tx_train) # p_data = model.predict(tx_train)
@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
# print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i]) # print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i])
x_train, y_train, tx_train, ty_train = load_gift_data(160) x_train, y_train, tx_train, ty_train, _ = load_gift_data(160)
model = load_model("./predict_gift") model = load_model("./predict_gift")
p_data = model.predict(tx_train) p_data = model.predict(tx_train)
for i in range(len(p_data)): for i in range(len(p_data)):

317
predict_pb2.py Normal file
View File

@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: predict.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='predict.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\rpredict.proto\"?\n\nRequestPay\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\"@\n\x0bRequestGift\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\",\n\x0bReplyHeader\x12\x0c\n\x04\x43ode\x18\x01 \x01(\x05\x12\x0f\n\x07Message\x18\x02 \x01(\t\"8\n\x08ReplyPay\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\"9\n\tReplyGift\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\x32T\n\x07Predict\x12\"\n\x06PayDay\x12\x0b.RequestPay\x1a\t.ReplyPay\"\x00\x12%\n\x07GiftDay\x12\x0c.RequestGift\x1a\n.ReplyGift\"\x00\x62\x06proto3'
)
_REQUESTPAY = _descriptor.Descriptor(
name='RequestPay',
full_name='RequestPay',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Hour', full_name='RequestPay.Hour', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Coin', full_name='RequestPay.Coin', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='YesterdayCoin', full_name='RequestPay.YesterdayCoin', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=17,
serialized_end=80,
)
_REQUESTGIFT = _descriptor.Descriptor(
name='RequestGift',
full_name='RequestGift',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Hour', full_name='RequestGift.Hour', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Coin', full_name='RequestGift.Coin', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='YesterdayCoin', full_name='RequestGift.YesterdayCoin', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=82,
serialized_end=146,
)
_REPLYHEADER = _descriptor.Descriptor(
name='ReplyHeader',
full_name='ReplyHeader',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Code', full_name='ReplyHeader.Code', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Message', full_name='ReplyHeader.Message', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=148,
serialized_end=192,
)
_REPLYPAY = _descriptor.Descriptor(
name='ReplyPay',
full_name='ReplyPay',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Header', full_name='ReplyPay.Header', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Result', full_name='ReplyPay.Result', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=194,
serialized_end=250,
)
_REPLYGIFT = _descriptor.Descriptor(
name='ReplyGift',
full_name='ReplyGift',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Header', full_name='ReplyGift.Header', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Result', full_name='ReplyGift.Result', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=252,
serialized_end=309,
)
_REPLYPAY.fields_by_name['Header'].message_type = _REPLYHEADER
_REPLYGIFT.fields_by_name['Header'].message_type = _REPLYHEADER
DESCRIPTOR.message_types_by_name['RequestPay'] = _REQUESTPAY
DESCRIPTOR.message_types_by_name['RequestGift'] = _REQUESTGIFT
DESCRIPTOR.message_types_by_name['ReplyHeader'] = _REPLYHEADER
DESCRIPTOR.message_types_by_name['ReplyPay'] = _REPLYPAY
DESCRIPTOR.message_types_by_name['ReplyGift'] = _REPLYGIFT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
RequestPay = _reflection.GeneratedProtocolMessageType('RequestPay', (_message.Message,), {
'DESCRIPTOR' : _REQUESTPAY,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:RequestPay)
})
_sym_db.RegisterMessage(RequestPay)
RequestGift = _reflection.GeneratedProtocolMessageType('RequestGift', (_message.Message,), {
'DESCRIPTOR' : _REQUESTGIFT,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:RequestGift)
})
_sym_db.RegisterMessage(RequestGift)
ReplyHeader = _reflection.GeneratedProtocolMessageType('ReplyHeader', (_message.Message,), {
'DESCRIPTOR' : _REPLYHEADER,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyHeader)
})
_sym_db.RegisterMessage(ReplyHeader)
ReplyPay = _reflection.GeneratedProtocolMessageType('ReplyPay', (_message.Message,), {
'DESCRIPTOR' : _REPLYPAY,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyPay)
})
_sym_db.RegisterMessage(ReplyPay)
ReplyGift = _reflection.GeneratedProtocolMessageType('ReplyGift', (_message.Message,), {
'DESCRIPTOR' : _REPLYGIFT,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyGift)
})
_sym_db.RegisterMessage(ReplyGift)
_PREDICT = _descriptor.ServiceDescriptor(
name='Predict',
full_name='Predict',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=311,
serialized_end=395,
methods=[
_descriptor.MethodDescriptor(
name='PayDay',
full_name='Predict.PayDay',
index=0,
containing_service=None,
input_type=_REQUESTPAY,
output_type=_REPLYPAY,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
_descriptor.MethodDescriptor(
name='GiftDay',
full_name='Predict.GiftDay',
index=1,
containing_service=None,
input_type=_REQUESTGIFT,
output_type=_REPLYGIFT,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_PREDICT)
DESCRIPTOR.services_by_name['Predict'] = _PREDICT
# @@protoc_insertion_point(module_scope)

99
predict_pb2_grpc.py Normal file
View File

@ -0,0 +1,99 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import predict_pb2 as predict__pb2
class PredictStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.PayDay = channel.unary_unary(
'/Predict/PayDay',
request_serializer=predict__pb2.RequestPay.SerializeToString,
response_deserializer=predict__pb2.ReplyPay.FromString,
)
self.GiftDay = channel.unary_unary(
'/Predict/GiftDay',
request_serializer=predict__pb2.RequestGift.SerializeToString,
response_deserializer=predict__pb2.ReplyGift.FromString,
)
class PredictServicer(object):
"""Missing associated documentation comment in .proto file."""
def PayDay(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GiftDay(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_PredictServicer_to_server(servicer, server):
rpc_method_handlers = {
'PayDay': grpc.unary_unary_rpc_method_handler(
servicer.PayDay,
request_deserializer=predict__pb2.RequestPay.FromString,
response_serializer=predict__pb2.ReplyPay.SerializeToString,
),
'GiftDay': grpc.unary_unary_rpc_method_handler(
servicer.GiftDay,
request_deserializer=predict__pb2.RequestGift.FromString,
response_serializer=predict__pb2.ReplyGift.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'Predict', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Predict(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def PayDay(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Predict/PayDay',
predict__pb2.RequestPay.SerializeToString,
predict__pb2.ReplyPay.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GiftDay(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Predict/GiftDay',
predict__pb2.RequestGift.SerializeToString,
predict__pb2.ReplyGift.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@ -2,3 +2,4 @@ tensorflow
keras keras
numpy numpy
pymysql pymysql
grpc_tools

View File

@ -15,12 +15,12 @@ from data import load_gift_data
if __name__ == "__main__": if __name__ == "__main__":
x_train, y_train, tx_train, ty_train = load_gift_data() x_train, y_train, tx_train, ty_train, input_shape = load_gift_data()
model = Sequential() model = Sequential()
units = 400 units = 400
model.add(LSTM(units, activation='relu', input_shape=(4,1) )) model.add(LSTM(units, activation='relu', input_shape=input_shape ))
model.add(Dropout(0.2)) model.add(Dropout(0.2))
model.add(Dense(1)) model.add(Dense(1))
@ -28,7 +28,7 @@ if __name__ == "__main__":
model.compile(loss='mse', optimizer='adam') model.compile(loss='mse', optimizer='adam')
model.fit(x_train, y_train, batch_size=1, epochs=50) model.fit(x_train, y_train, batch_size=128, epochs=1500)
model.save("./predict_gift") model.save("./predict_gift")
p_data = model.predict(tx_train) p_data = model.predict(tx_train)

View File

@ -5,6 +5,7 @@ from keras.layers import Dense, Dropout, Embedding
from keras.layers import InputLayer from keras.layers import InputLayer
from keras.layers import LSTM from keras.layers import LSTM
from keras import backend from keras import backend
from keras.layers.recurrent import SimpleRNN
import pymysql import pymysql
import pickle import pickle
@ -16,18 +17,20 @@ from data import load_pay_data
if __name__ == "__main__": if __name__ == "__main__":
x_train, y_train, tx_train, ty_train = load_pay_data(80) x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80)
model = Sequential() model = Sequential()
units = 500 units = 500
model.add(LSTM(units, activation='relu', input_shape=(3,1))) model.add(LSTM(units, activation='relu', dropout=0.1, input_shape=input_shape))
model.add(Dropout(0.3))
# model.add(SimpleRNN(units, activation='relu'))
# model.add(Dropout(0.1))
model.add(Dense(1)) model.add(Dense(1))
model.summary() model.summary()
model.compile(loss='mse', optimizer='adam') model.compile(loss = 'mse', optimizer = 'adam')
model.fit(x_train, y_train, batch_size=1, epochs=50) model.fit(x_train, y_train, batch_size=96, epochs=1200)
model.save("./predict_pay") model.save("./predict_pay")
p_data = model.predict(tx_train) p_data = model.predict(tx_train)