完成预测模型的grpc服务

This commit is contained in:
eson 2021-03-25 17:36:09 +08:00
parent 36b30982f7
commit bb90138620
14 changed files with 553 additions and 47 deletions

4
__init__.py Normal file
View File

@ -0,0 +1,4 @@
import sys
import os
sys.path.append(".")

39
api/predict.proto Normal file
View File

@ -0,0 +1,39 @@
syntax = "proto3";
service Predict {
rpc PayDay(RequestPay) returns (ReplyPay) {}
rpc GiftDay(RequestGift) returns (ReplyGift) {}
}
// Request Request请求
message RequestPay {
// count, total_coin / last_total_coin, total_coin // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// RequestGift Request请求
message RequestGift {
// count, total_coin / last_total_coin, total_coin // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// ReplyHeader ReplyHeader
message ReplyHeader {
int32 Code = 1;
string Message = 2;
}
message ReplyPay {
ReplyHeader Header = 1;
int64 Result = 2;
}
message ReplyGift {
ReplyHeader Header = 1;
int64 Result = 2;
}

24
data.py
View File

@ -85,6 +85,7 @@ def get_collect():
def load_pay_data(textNum = 80):
collect = get_collect()
# TODO: 处理gift pay的波动关系
@ -101,6 +102,7 @@ def load_pay_data(textNum = 80):
for cur_v in collect_pay[1:]:
total_coin = 0
users = 0
last_total_coin = 0
for v2 in lastday_v:
@ -109,6 +111,7 @@ def load_pay_data(textNum = 80):
count = 0
for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0] + v1[1]
users += v1[2]
# print(v1[3])
# last_total_coin += v2[0] + v2[1]
@ -116,7 +119,8 @@ def load_pay_data(textNum = 80):
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
# print(compare)
x_train.append([count, total_coin, total_coin/last_total_coin])
# 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
x_train.append([count ,total_coin / last_total_coin , total_coin])
count+=1
for i in range(count):
@ -124,7 +128,8 @@ def load_pay_data(textNum = 80):
lastday_v = cur_v
x_train = numpy.reshape(x_train, (len(x_train) , 3, 1))
input_shape = (len(x_train[0]), 1)
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
y_train = numpy.reshape(y_train, (len(y_train)))
# max_features = 1024
@ -134,7 +139,7 @@ def load_pay_data(textNum = 80):
# x_train = x_train[:len(x_train) - textNum]
# y_train = y_train[:len(y_train) - textNum]
return x_train, y_train, tx_train, ty_train
return x_train, y_train, tx_train, ty_train, input_shape
def load_gift_data(textNum = 80):
@ -159,7 +164,7 @@ def load_gift_data(textNum = 80):
last_total_coin += v2[0]
f = 20000000.0
count = 1
count = 0
for v1, v2 in zip(cur_v,lastday_v):
total_coin += v1[0]
# print(v1[3])
@ -170,16 +175,17 @@ def load_gift_data(textNum = 80):
# print(v2[3])
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
# print(compare)
x_train.append([count, total_coin, total_coin / last_total_coin, users ])
# 参数 前一小个小时. 时刻. 当前金钱. 送礼人数
x_train.append([count, total_coin / last_total_coin, total_coin ])
count+=1
for i in range(count - 1):
for i in range(count):
y_train.append(total_coin)
lastday_v = cur_v
x_train = numpy.reshape(x_train, (len(x_train) , 4, 1))
input_shape = (len(x_train[0]), 1)
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
y_train = numpy.reshape(y_train, (len(y_train)))
# max_features = 1024
@ -189,4 +195,4 @@ def load_gift_data(textNum = 80):
# x_train = x_train[:len(x_train) - textNum]
# y_train = y_train[:len(y_train) - textNum]
return x_train, y_train, tx_train, ty_train
return x_train, y_train, tx_train, ty_train, input_shape

View File

@ -1 +0,0 @@

3
gen_proto3.sh Normal file
View File

@ -0,0 +1,3 @@
#! /bin/bash
PBPATH=./api
python -m grpc_tools.protoc -I$PBPATH --python_out=. --grpc_python_out=. $PBPATH/*.proto

20
grpc_client.py Normal file
View File

@ -0,0 +1,20 @@
import grpc
import logging
import predict_pb2, predict_pb2_grpc
def run():
option = [('grpc.keepalive_timeout_ms', 10000)]
with grpc.insecure_channel(target='localhost:50051', options=option) as channel:
stub = predict_pb2_grpc.PredictStub(channel)
response = stub.PayDay( predict_pb2.RequestPay(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
print(response, type(response.Header.Code))
response = stub.GiftDay(predict_pb2.RequestGift(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
print(response, type(response.Header.Code))
if __name__ == '__main__':
logging.basicConfig()
run()

39
grpc_server.py Normal file
View File

@ -0,0 +1,39 @@
import grpc
import logging
from concurrent import futures
import predict_pb2, predict_pb2_grpc
from predict_pb2 import *
from predict_pb2_grpc import *
import numpy
from keras.models import load_model
from data import load_pay_data, load_gift_data
pay_model = load_model("./predict_pay")
gift_model = load_model("./predict_gift")
class Predict(predict_pb2_grpc.PredictServicer):
def PayDay(self, request, context):
# print(request)
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
predict_value = pay_model.predict(inputx)
return ReplyPay(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
# return super().PayDay(request, context)
def GiftDay(self, request, context):
# print(request)
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
predict_value = gift_model.predict(inputx)
return ReplyGift(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
def server():
grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
predict_pb2_grpc.add_PredictServicer_to_server(Predict(), grpc_server)
grpc_server.add_insecure_port('[::]:50051')
grpc_server.start()
grpc_server.wait_for_termination()
if __name__ == '__main__':
logging.basicConfig()
server()

View File

@ -1,24 +0,0 @@
syntax = "proto3";
service Predict {
rpc PayDay(RequestPay) returns (Reply) {}
rpc GiftDay(RequestGift) returns (Reply) {}
}
// Request Request请求
message RequestPay {
// count, total_coin , (total_coin - last_total_coin) , v1[2] , v2[2]] // 24
int32 Hour = 1;
int64 Coin = 2;
int64 YesterdayCoin = 3;
}
// RequestGift Request请求
message RequestGift {
}
// Reply Reply
message Reply {
}

View File

@ -5,7 +5,7 @@ from data import load_pay_data, load_gift_data
import matplotlib.pyplot as plt
# x_train, y_train, tx_train, ty_train = load_pay_data(160)
# x_train, y_train, tx_train, ty_train, _ = load_pay_data(160)
# model = load_model("./predict_pay")
# p_data = model.predict(tx_train)
@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
# print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i])
x_train, y_train, tx_train, ty_train = load_gift_data(160)
x_train, y_train, tx_train, ty_train, _ = load_gift_data(160)
model = load_model("./predict_gift")
p_data = model.predict(tx_train)
for i in range(len(p_data)):

317
predict_pb2.py Normal file
View File

@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: predict.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='predict.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\rpredict.proto\"?\n\nRequestPay\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\"@\n\x0bRequestGift\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\",\n\x0bReplyHeader\x12\x0c\n\x04\x43ode\x18\x01 \x01(\x05\x12\x0f\n\x07Message\x18\x02 \x01(\t\"8\n\x08ReplyPay\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\"9\n\tReplyGift\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\x32T\n\x07Predict\x12\"\n\x06PayDay\x12\x0b.RequestPay\x1a\t.ReplyPay\"\x00\x12%\n\x07GiftDay\x12\x0c.RequestGift\x1a\n.ReplyGift\"\x00\x62\x06proto3'
)
_REQUESTPAY = _descriptor.Descriptor(
name='RequestPay',
full_name='RequestPay',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Hour', full_name='RequestPay.Hour', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Coin', full_name='RequestPay.Coin', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='YesterdayCoin', full_name='RequestPay.YesterdayCoin', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=17,
serialized_end=80,
)
_REQUESTGIFT = _descriptor.Descriptor(
name='RequestGift',
full_name='RequestGift',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Hour', full_name='RequestGift.Hour', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Coin', full_name='RequestGift.Coin', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='YesterdayCoin', full_name='RequestGift.YesterdayCoin', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=82,
serialized_end=146,
)
_REPLYHEADER = _descriptor.Descriptor(
name='ReplyHeader',
full_name='ReplyHeader',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Code', full_name='ReplyHeader.Code', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Message', full_name='ReplyHeader.Message', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=148,
serialized_end=192,
)
_REPLYPAY = _descriptor.Descriptor(
name='ReplyPay',
full_name='ReplyPay',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Header', full_name='ReplyPay.Header', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Result', full_name='ReplyPay.Result', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=194,
serialized_end=250,
)
_REPLYGIFT = _descriptor.Descriptor(
name='ReplyGift',
full_name='ReplyGift',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='Header', full_name='ReplyGift.Header', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='Result', full_name='ReplyGift.Result', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=252,
serialized_end=309,
)
_REPLYPAY.fields_by_name['Header'].message_type = _REPLYHEADER
_REPLYGIFT.fields_by_name['Header'].message_type = _REPLYHEADER
DESCRIPTOR.message_types_by_name['RequestPay'] = _REQUESTPAY
DESCRIPTOR.message_types_by_name['RequestGift'] = _REQUESTGIFT
DESCRIPTOR.message_types_by_name['ReplyHeader'] = _REPLYHEADER
DESCRIPTOR.message_types_by_name['ReplyPay'] = _REPLYPAY
DESCRIPTOR.message_types_by_name['ReplyGift'] = _REPLYGIFT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
RequestPay = _reflection.GeneratedProtocolMessageType('RequestPay', (_message.Message,), {
'DESCRIPTOR' : _REQUESTPAY,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:RequestPay)
})
_sym_db.RegisterMessage(RequestPay)
RequestGift = _reflection.GeneratedProtocolMessageType('RequestGift', (_message.Message,), {
'DESCRIPTOR' : _REQUESTGIFT,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:RequestGift)
})
_sym_db.RegisterMessage(RequestGift)
ReplyHeader = _reflection.GeneratedProtocolMessageType('ReplyHeader', (_message.Message,), {
'DESCRIPTOR' : _REPLYHEADER,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyHeader)
})
_sym_db.RegisterMessage(ReplyHeader)
ReplyPay = _reflection.GeneratedProtocolMessageType('ReplyPay', (_message.Message,), {
'DESCRIPTOR' : _REPLYPAY,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyPay)
})
_sym_db.RegisterMessage(ReplyPay)
ReplyGift = _reflection.GeneratedProtocolMessageType('ReplyGift', (_message.Message,), {
'DESCRIPTOR' : _REPLYGIFT,
'__module__' : 'predict_pb2'
# @@protoc_insertion_point(class_scope:ReplyGift)
})
_sym_db.RegisterMessage(ReplyGift)
_PREDICT = _descriptor.ServiceDescriptor(
name='Predict',
full_name='Predict',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=311,
serialized_end=395,
methods=[
_descriptor.MethodDescriptor(
name='PayDay',
full_name='Predict.PayDay',
index=0,
containing_service=None,
input_type=_REQUESTPAY,
output_type=_REPLYPAY,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
_descriptor.MethodDescriptor(
name='GiftDay',
full_name='Predict.GiftDay',
index=1,
containing_service=None,
input_type=_REQUESTGIFT,
output_type=_REPLYGIFT,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_PREDICT)
DESCRIPTOR.services_by_name['Predict'] = _PREDICT
# @@protoc_insertion_point(module_scope)

99
predict_pb2_grpc.py Normal file
View File

@ -0,0 +1,99 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import predict_pb2 as predict__pb2
class PredictStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.PayDay = channel.unary_unary(
'/Predict/PayDay',
request_serializer=predict__pb2.RequestPay.SerializeToString,
response_deserializer=predict__pb2.ReplyPay.FromString,
)
self.GiftDay = channel.unary_unary(
'/Predict/GiftDay',
request_serializer=predict__pb2.RequestGift.SerializeToString,
response_deserializer=predict__pb2.ReplyGift.FromString,
)
class PredictServicer(object):
"""Missing associated documentation comment in .proto file."""
def PayDay(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GiftDay(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_PredictServicer_to_server(servicer, server):
rpc_method_handlers = {
'PayDay': grpc.unary_unary_rpc_method_handler(
servicer.PayDay,
request_deserializer=predict__pb2.RequestPay.FromString,
response_serializer=predict__pb2.ReplyPay.SerializeToString,
),
'GiftDay': grpc.unary_unary_rpc_method_handler(
servicer.GiftDay,
request_deserializer=predict__pb2.RequestGift.FromString,
response_serializer=predict__pb2.ReplyGift.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'Predict', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Predict(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def PayDay(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Predict/PayDay',
predict__pb2.RequestPay.SerializeToString,
predict__pb2.ReplyPay.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GiftDay(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/Predict/GiftDay',
predict__pb2.RequestGift.SerializeToString,
predict__pb2.ReplyGift.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@ -2,3 +2,4 @@ tensorflow
keras
numpy
pymysql
grpc_tools

View File

@ -15,12 +15,12 @@ from data import load_gift_data
if __name__ == "__main__":
x_train, y_train, tx_train, ty_train = load_gift_data()
x_train, y_train, tx_train, ty_train, input_shape = load_gift_data()
model = Sequential()
units = 400
model.add(LSTM(units, activation='relu', input_shape=(4,1) ))
model.add(LSTM(units, activation='relu', input_shape=input_shape ))
model.add(Dropout(0.2))
model.add(Dense(1))
@ -28,7 +28,7 @@ if __name__ == "__main__":
model.compile(loss='mse', optimizer='adam')
model.fit(x_train, y_train, batch_size=1, epochs=50)
model.fit(x_train, y_train, batch_size=128, epochs=1500)
model.save("./predict_gift")
p_data = model.predict(tx_train)

View File

@ -5,6 +5,7 @@ from keras.layers import Dense, Dropout, Embedding
from keras.layers import InputLayer
from keras.layers import LSTM
from keras import backend
from keras.layers.recurrent import SimpleRNN
import pymysql
import pickle
@ -16,18 +17,20 @@ from data import load_pay_data
if __name__ == "__main__":
x_train, y_train, tx_train, ty_train = load_pay_data(80)
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80)
model = Sequential()
units = 500
model.add(LSTM(units, activation='relu', input_shape=(3,1)))
model.add(Dropout(0.3))
model.add(LSTM(units, activation='relu', dropout=0.1, input_shape=input_shape))
# model.add(SimpleRNN(units, activation='relu'))
# model.add(Dropout(0.1))
model.add(Dense(1))
model.summary()
model.compile(loss='mse', optimizer='adam')
model.compile(loss = 'mse', optimizer = 'adam')
model.fit(x_train, y_train, batch_size=1, epochs=50)
model.fit(x_train, y_train, batch_size=96, epochs=1200)
model.save("./predict_pay")
p_data = model.predict(tx_train)