完成预测模型的grpc服务
This commit is contained in:
parent
36b30982f7
commit
bb90138620
4
__init__.py
Normal file
4
__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.append(".")
|
39
api/predict.proto
Normal file
39
api/predict.proto
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
service Predict {
|
||||||
|
rpc PayDay(RequestPay) returns (ReplyPay) {}
|
||||||
|
rpc GiftDay(RequestGift) returns (ReplyGift) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request 统一Request请求
|
||||||
|
message RequestPay {
|
||||||
|
// count, total_coin / last_total_coin, total_coin // 24小时
|
||||||
|
int32 Hour = 1;
|
||||||
|
int64 Coin = 2;
|
||||||
|
int64 YesterdayCoin = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestGift 统一Request请求
|
||||||
|
message RequestGift {
|
||||||
|
// count, total_coin / last_total_coin, total_coin // 24小时
|
||||||
|
int32 Hour = 1;
|
||||||
|
int64 Coin = 2;
|
||||||
|
int64 YesterdayCoin = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyHeader 统一ReplyHeader 响应
|
||||||
|
message ReplyHeader {
|
||||||
|
int32 Code = 1;
|
||||||
|
string Message = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ReplyPay {
|
||||||
|
ReplyHeader Header = 1;
|
||||||
|
int64 Result = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message ReplyGift {
|
||||||
|
ReplyHeader Header = 1;
|
||||||
|
int64 Result = 2;
|
||||||
|
}
|
30
data.py
30
data.py
|
@ -81,9 +81,10 @@ def get_collect():
|
||||||
finally:
|
finally:
|
||||||
return collect
|
return collect
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_pay_data(textNum = 80):
|
def load_pay_data(textNum = 80):
|
||||||
|
|
||||||
|
|
||||||
collect = get_collect()
|
collect = get_collect()
|
||||||
|
|
||||||
|
@ -101,6 +102,7 @@ def load_pay_data(textNum = 80):
|
||||||
for cur_v in collect_pay[1:]:
|
for cur_v in collect_pay[1:]:
|
||||||
|
|
||||||
total_coin = 0
|
total_coin = 0
|
||||||
|
users = 0
|
||||||
last_total_coin = 0
|
last_total_coin = 0
|
||||||
|
|
||||||
for v2 in lastday_v:
|
for v2 in lastday_v:
|
||||||
|
@ -109,14 +111,16 @@ def load_pay_data(textNum = 80):
|
||||||
count = 0
|
count = 0
|
||||||
for v1, v2 in zip(cur_v,lastday_v):
|
for v1, v2 in zip(cur_v,lastday_v):
|
||||||
total_coin += v1[0] + v1[1]
|
total_coin += v1[0] + v1[1]
|
||||||
|
users += v1[2]
|
||||||
# print(v1[3])
|
# print(v1[3])
|
||||||
|
|
||||||
# last_total_coin += v2[0] + v2[1]
|
# last_total_coin += v2[0] + v2[1]
|
||||||
# print(v2[3])
|
# print(v2[3])
|
||||||
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
|
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
|
||||||
# print(compare)
|
# print(compare)
|
||||||
|
|
||||||
x_train.append([count, total_coin, total_coin/last_total_coin])
|
# 时刻. 前一个小时 时刻. 当前支付总币数. 当前支付总币数 昨天币数
|
||||||
|
x_train.append([count ,total_coin / last_total_coin , total_coin])
|
||||||
count+=1
|
count+=1
|
||||||
|
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
|
@ -124,7 +128,8 @@ def load_pay_data(textNum = 80):
|
||||||
|
|
||||||
lastday_v = cur_v
|
lastday_v = cur_v
|
||||||
|
|
||||||
x_train = numpy.reshape(x_train, (len(x_train) , 3, 1))
|
input_shape = (len(x_train[0]), 1)
|
||||||
|
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
|
||||||
y_train = numpy.reshape(y_train, (len(y_train)))
|
y_train = numpy.reshape(y_train, (len(y_train)))
|
||||||
# max_features = 1024
|
# max_features = 1024
|
||||||
|
|
||||||
|
@ -134,7 +139,7 @@ def load_pay_data(textNum = 80):
|
||||||
# x_train = x_train[:len(x_train) - textNum]
|
# x_train = x_train[:len(x_train) - textNum]
|
||||||
# y_train = y_train[:len(y_train) - textNum]
|
# y_train = y_train[:len(y_train) - textNum]
|
||||||
|
|
||||||
return x_train, y_train, tx_train, ty_train
|
return x_train, y_train, tx_train, ty_train, input_shape
|
||||||
|
|
||||||
def load_gift_data(textNum = 80):
|
def load_gift_data(textNum = 80):
|
||||||
|
|
||||||
|
@ -159,7 +164,7 @@ def load_gift_data(textNum = 80):
|
||||||
last_total_coin += v2[0]
|
last_total_coin += v2[0]
|
||||||
|
|
||||||
f = 20000000.0
|
f = 20000000.0
|
||||||
count = 1
|
count = 0
|
||||||
for v1, v2 in zip(cur_v,lastday_v):
|
for v1, v2 in zip(cur_v,lastday_v):
|
||||||
total_coin += v1[0]
|
total_coin += v1[0]
|
||||||
# print(v1[3])
|
# print(v1[3])
|
||||||
|
@ -170,16 +175,17 @@ def load_gift_data(textNum = 80):
|
||||||
# print(v2[3])
|
# print(v2[3])
|
||||||
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
|
# compare = float(total_coin - last_total_coin) / float(last_total_coin)
|
||||||
# print(compare)
|
# print(compare)
|
||||||
|
# 参数 前一小个小时. 时刻. 当前金钱. 送礼人数
|
||||||
x_train.append([count, total_coin, total_coin / last_total_coin, users ])
|
x_train.append([count, total_coin / last_total_coin, total_coin ])
|
||||||
count+=1
|
count+=1
|
||||||
|
|
||||||
for i in range(count - 1):
|
for i in range(count):
|
||||||
y_train.append(total_coin)
|
y_train.append(total_coin)
|
||||||
|
|
||||||
lastday_v = cur_v
|
lastday_v = cur_v
|
||||||
|
|
||||||
x_train = numpy.reshape(x_train, (len(x_train) , 4, 1))
|
input_shape = (len(x_train[0]), 1)
|
||||||
|
x_train = numpy.reshape(x_train, (len(x_train) , input_shape[0], input_shape[1]))
|
||||||
y_train = numpy.reshape(y_train, (len(y_train)))
|
y_train = numpy.reshape(y_train, (len(y_train)))
|
||||||
# max_features = 1024
|
# max_features = 1024
|
||||||
|
|
||||||
|
@ -189,4 +195,4 @@ def load_gift_data(textNum = 80):
|
||||||
# x_train = x_train[:len(x_train) - textNum]
|
# x_train = x_train[:len(x_train) - textNum]
|
||||||
# y_train = y_train[:len(y_train) - textNum]
|
# y_train = y_train[:len(y_train) - textNum]
|
||||||
|
|
||||||
return x_train, y_train, tx_train, ty_train
|
return x_train, y_train, tx_train, ty_train, input_shape
|
|
@ -1 +0,0 @@
|
||||||
|
|
3
gen_proto3.sh
Normal file
3
gen_proto3.sh
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
#! /bin/bash
|
||||||
|
PBPATH=./api
|
||||||
|
python -m grpc_tools.protoc -I$PBPATH --python_out=. --grpc_python_out=. $PBPATH/*.proto
|
20
grpc_client.py
Normal file
20
grpc_client.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import grpc
|
||||||
|
import logging
|
||||||
|
import predict_pb2, predict_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
option = [('grpc.keepalive_timeout_ms', 10000)]
|
||||||
|
with grpc.insecure_channel(target='localhost:50051', options=option) as channel:
|
||||||
|
stub = predict_pb2_grpc.PredictStub(channel)
|
||||||
|
|
||||||
|
response = stub.PayDay( predict_pb2.RequestPay(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
|
||||||
|
print(response, type(response.Header.Code))
|
||||||
|
|
||||||
|
response = stub.GiftDay(predict_pb2.RequestGift(Hour=0,Coin=100,YesterdayCoin=100), timeout=10)
|
||||||
|
print(response, type(response.Header.Code))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig()
|
||||||
|
run()
|
39
grpc_server.py
Normal file
39
grpc_server.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import logging
|
||||||
|
from concurrent import futures
|
||||||
|
|
||||||
|
import predict_pb2, predict_pb2_grpc
|
||||||
|
from predict_pb2 import *
|
||||||
|
from predict_pb2_grpc import *
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from keras.models import load_model
|
||||||
|
from data import load_pay_data, load_gift_data
|
||||||
|
pay_model = load_model("./predict_pay")
|
||||||
|
gift_model = load_model("./predict_gift")
|
||||||
|
|
||||||
|
class Predict(predict_pb2_grpc.PredictServicer):
|
||||||
|
def PayDay(self, request, context):
|
||||||
|
# print(request)
|
||||||
|
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
|
||||||
|
predict_value = pay_model.predict(inputx)
|
||||||
|
return ReplyPay(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
|
||||||
|
# return super().PayDay(request, context)
|
||||||
|
def GiftDay(self, request, context):
|
||||||
|
# print(request)
|
||||||
|
inputx = numpy.reshape([request.Hour, request.Coin / request.YesterdayCoin, request.Coin], (1, 3, 1))
|
||||||
|
predict_value = gift_model.predict(inputx)
|
||||||
|
return ReplyGift(Header=ReplyHeader(Code=0, Message=""), Result=int(predict_value[0][0]))
|
||||||
|
|
||||||
|
def server():
|
||||||
|
grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
|
||||||
|
predict_pb2_grpc.add_PredictServicer_to_server(Predict(), grpc_server)
|
||||||
|
grpc_server.add_insecure_port('[::]:50051')
|
||||||
|
grpc_server.start()
|
||||||
|
grpc_server.wait_for_termination()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig()
|
||||||
|
server()
|
|
@ -1,24 +0,0 @@
|
||||||
syntax = "proto3";
|
|
||||||
|
|
||||||
service Predict {
|
|
||||||
rpc PayDay(RequestPay) returns (Reply) {}
|
|
||||||
rpc GiftDay(RequestGift) returns (Reply) {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request 统一Request请求
|
|
||||||
message RequestPay {
|
|
||||||
// count, total_coin , (total_coin - last_total_coin) , v1[2] , v2[2]] // 24小时
|
|
||||||
int32 Hour = 1;
|
|
||||||
int64 Coin = 2;
|
|
||||||
int64 YesterdayCoin = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestGift 统一Request请求
|
|
||||||
message RequestGift {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reply 统一Reply 响应
|
|
||||||
message Reply {
|
|
||||||
|
|
||||||
}
|
|
|
@ -5,7 +5,7 @@ from data import load_pay_data, load_gift_data
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
# x_train, y_train, tx_train, ty_train = load_pay_data(160)
|
# x_train, y_train, tx_train, ty_train, _ = load_pay_data(160)
|
||||||
# model = load_model("./predict_pay")
|
# model = load_model("./predict_pay")
|
||||||
|
|
||||||
# p_data = model.predict(tx_train)
|
# p_data = model.predict(tx_train)
|
||||||
|
@ -16,7 +16,7 @@ import matplotlib.pyplot as plt
|
||||||
# print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i])
|
# print("测结果:", p_data[i][0], "测:", tx_train[i], "真实:", ty_train[i])
|
||||||
|
|
||||||
|
|
||||||
x_train, y_train, tx_train, ty_train = load_gift_data(160)
|
x_train, y_train, tx_train, ty_train, _ = load_gift_data(160)
|
||||||
model = load_model("./predict_gift")
|
model = load_model("./predict_gift")
|
||||||
p_data = model.predict(tx_train)
|
p_data = model.predict(tx_train)
|
||||||
for i in range(len(p_data)):
|
for i in range(len(p_data)):
|
||||||
|
|
317
predict_pb2.py
Normal file
317
predict_pb2.py
Normal file
|
@ -0,0 +1,317 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# source: predict.proto
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import message as _message
|
||||||
|
from google.protobuf import reflection as _reflection
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor.FileDescriptor(
|
||||||
|
name='predict.proto',
|
||||||
|
package='',
|
||||||
|
syntax='proto3',
|
||||||
|
serialized_options=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
serialized_pb=b'\n\rpredict.proto\"?\n\nRequestPay\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\"@\n\x0bRequestGift\x12\x0c\n\x04Hour\x18\x01 \x01(\x05\x12\x0c\n\x04\x43oin\x18\x02 \x01(\x03\x12\x15\n\rYesterdayCoin\x18\x03 \x01(\x03\",\n\x0bReplyHeader\x12\x0c\n\x04\x43ode\x18\x01 \x01(\x05\x12\x0f\n\x07Message\x18\x02 \x01(\t\"8\n\x08ReplyPay\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\"9\n\tReplyGift\x12\x1c\n\x06Header\x18\x01 \x01(\x0b\x32\x0c.ReplyHeader\x12\x0e\n\x06Result\x18\x02 \x01(\x03\x32T\n\x07Predict\x12\"\n\x06PayDay\x12\x0b.RequestPay\x1a\t.ReplyPay\"\x00\x12%\n\x07GiftDay\x12\x0c.RequestGift\x1a\n.ReplyGift\"\x00\x62\x06proto3'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_REQUESTPAY = _descriptor.Descriptor(
|
||||||
|
name='RequestPay',
|
||||||
|
full_name='RequestPay',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Hour', full_name='RequestPay.Hour', index=0,
|
||||||
|
number=1, type=5, cpp_type=1, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Coin', full_name='RequestPay.Coin', index=1,
|
||||||
|
number=2, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='YesterdayCoin', full_name='RequestPay.YesterdayCoin', index=2,
|
||||||
|
number=3, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=17,
|
||||||
|
serialized_end=80,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_REQUESTGIFT = _descriptor.Descriptor(
|
||||||
|
name='RequestGift',
|
||||||
|
full_name='RequestGift',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Hour', full_name='RequestGift.Hour', index=0,
|
||||||
|
number=1, type=5, cpp_type=1, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Coin', full_name='RequestGift.Coin', index=1,
|
||||||
|
number=2, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='YesterdayCoin', full_name='RequestGift.YesterdayCoin', index=2,
|
||||||
|
number=3, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=82,
|
||||||
|
serialized_end=146,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_REPLYHEADER = _descriptor.Descriptor(
|
||||||
|
name='ReplyHeader',
|
||||||
|
full_name='ReplyHeader',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Code', full_name='ReplyHeader.Code', index=0,
|
||||||
|
number=1, type=5, cpp_type=1, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Message', full_name='ReplyHeader.Message', index=1,
|
||||||
|
number=2, type=9, cpp_type=9, label=1,
|
||||||
|
has_default_value=False, default_value=b"".decode('utf-8'),
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=148,
|
||||||
|
serialized_end=192,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_REPLYPAY = _descriptor.Descriptor(
|
||||||
|
name='ReplyPay',
|
||||||
|
full_name='ReplyPay',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Header', full_name='ReplyPay.Header', index=0,
|
||||||
|
number=1, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Result', full_name='ReplyPay.Result', index=1,
|
||||||
|
number=2, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=194,
|
||||||
|
serialized_end=250,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_REPLYGIFT = _descriptor.Descriptor(
|
||||||
|
name='ReplyGift',
|
||||||
|
full_name='ReplyGift',
|
||||||
|
filename=None,
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
containing_type=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
fields=[
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Header', full_name='ReplyGift.Header', index=0,
|
||||||
|
number=1, type=11, cpp_type=10, label=1,
|
||||||
|
has_default_value=False, default_value=None,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
_descriptor.FieldDescriptor(
|
||||||
|
name='Result', full_name='ReplyGift.Result', index=1,
|
||||||
|
number=2, type=3, cpp_type=2, label=1,
|
||||||
|
has_default_value=False, default_value=0,
|
||||||
|
message_type=None, enum_type=None, containing_type=None,
|
||||||
|
is_extension=False, extension_scope=None,
|
||||||
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
||||||
|
],
|
||||||
|
extensions=[
|
||||||
|
],
|
||||||
|
nested_types=[],
|
||||||
|
enum_types=[
|
||||||
|
],
|
||||||
|
serialized_options=None,
|
||||||
|
is_extendable=False,
|
||||||
|
syntax='proto3',
|
||||||
|
extension_ranges=[],
|
||||||
|
oneofs=[
|
||||||
|
],
|
||||||
|
serialized_start=252,
|
||||||
|
serialized_end=309,
|
||||||
|
)
|
||||||
|
|
||||||
|
_REPLYPAY.fields_by_name['Header'].message_type = _REPLYHEADER
|
||||||
|
_REPLYGIFT.fields_by_name['Header'].message_type = _REPLYHEADER
|
||||||
|
DESCRIPTOR.message_types_by_name['RequestPay'] = _REQUESTPAY
|
||||||
|
DESCRIPTOR.message_types_by_name['RequestGift'] = _REQUESTGIFT
|
||||||
|
DESCRIPTOR.message_types_by_name['ReplyHeader'] = _REPLYHEADER
|
||||||
|
DESCRIPTOR.message_types_by_name['ReplyPay'] = _REPLYPAY
|
||||||
|
DESCRIPTOR.message_types_by_name['ReplyGift'] = _REPLYGIFT
|
||||||
|
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
||||||
|
|
||||||
|
RequestPay = _reflection.GeneratedProtocolMessageType('RequestPay', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _REQUESTPAY,
|
||||||
|
'__module__' : 'predict_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:RequestPay)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(RequestPay)
|
||||||
|
|
||||||
|
RequestGift = _reflection.GeneratedProtocolMessageType('RequestGift', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _REQUESTGIFT,
|
||||||
|
'__module__' : 'predict_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:RequestGift)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(RequestGift)
|
||||||
|
|
||||||
|
ReplyHeader = _reflection.GeneratedProtocolMessageType('ReplyHeader', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _REPLYHEADER,
|
||||||
|
'__module__' : 'predict_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:ReplyHeader)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(ReplyHeader)
|
||||||
|
|
||||||
|
ReplyPay = _reflection.GeneratedProtocolMessageType('ReplyPay', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _REPLYPAY,
|
||||||
|
'__module__' : 'predict_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:ReplyPay)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(ReplyPay)
|
||||||
|
|
||||||
|
ReplyGift = _reflection.GeneratedProtocolMessageType('ReplyGift', (_message.Message,), {
|
||||||
|
'DESCRIPTOR' : _REPLYGIFT,
|
||||||
|
'__module__' : 'predict_pb2'
|
||||||
|
# @@protoc_insertion_point(class_scope:ReplyGift)
|
||||||
|
})
|
||||||
|
_sym_db.RegisterMessage(ReplyGift)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
_PREDICT = _descriptor.ServiceDescriptor(
|
||||||
|
name='Predict',
|
||||||
|
full_name='Predict',
|
||||||
|
file=DESCRIPTOR,
|
||||||
|
index=0,
|
||||||
|
serialized_options=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
serialized_start=311,
|
||||||
|
serialized_end=395,
|
||||||
|
methods=[
|
||||||
|
_descriptor.MethodDescriptor(
|
||||||
|
name='PayDay',
|
||||||
|
full_name='Predict.PayDay',
|
||||||
|
index=0,
|
||||||
|
containing_service=None,
|
||||||
|
input_type=_REQUESTPAY,
|
||||||
|
output_type=_REPLYPAY,
|
||||||
|
serialized_options=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
),
|
||||||
|
_descriptor.MethodDescriptor(
|
||||||
|
name='GiftDay',
|
||||||
|
full_name='Predict.GiftDay',
|
||||||
|
index=1,
|
||||||
|
containing_service=None,
|
||||||
|
input_type=_REQUESTGIFT,
|
||||||
|
output_type=_REPLYGIFT,
|
||||||
|
serialized_options=None,
|
||||||
|
create_key=_descriptor._internal_create_key,
|
||||||
|
),
|
||||||
|
])
|
||||||
|
_sym_db.RegisterServiceDescriptor(_PREDICT)
|
||||||
|
|
||||||
|
DESCRIPTOR.services_by_name['Predict'] = _PREDICT
|
||||||
|
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
99
predict_pb2_grpc.py
Normal file
99
predict_pb2_grpc.py
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
import predict_pb2 as predict__pb2
|
||||||
|
|
||||||
|
|
||||||
|
class PredictStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.PayDay = channel.unary_unary(
|
||||||
|
'/Predict/PayDay',
|
||||||
|
request_serializer=predict__pb2.RequestPay.SerializeToString,
|
||||||
|
response_deserializer=predict__pb2.ReplyPay.FromString,
|
||||||
|
)
|
||||||
|
self.GiftDay = channel.unary_unary(
|
||||||
|
'/Predict/GiftDay',
|
||||||
|
request_serializer=predict__pb2.RequestGift.SerializeToString,
|
||||||
|
response_deserializer=predict__pb2.ReplyGift.FromString,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PredictServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def PayDay(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
def GiftDay(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_PredictServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'PayDay': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.PayDay,
|
||||||
|
request_deserializer=predict__pb2.RequestPay.FromString,
|
||||||
|
response_serializer=predict__pb2.ReplyPay.SerializeToString,
|
||||||
|
),
|
||||||
|
'GiftDay': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.GiftDay,
|
||||||
|
request_deserializer=predict__pb2.RequestGift.FromString,
|
||||||
|
response_serializer=predict__pb2.ReplyGift.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'Predict', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class Predict(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def PayDay(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/Predict/PayDay',
|
||||||
|
predict__pb2.RequestPay.SerializeToString,
|
||||||
|
predict__pb2.ReplyPay.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def GiftDay(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(request, target, '/Predict/GiftDay',
|
||||||
|
predict__pb2.RequestGift.SerializeToString,
|
||||||
|
predict__pb2.ReplyGift.FromString,
|
||||||
|
options, channel_credentials,
|
||||||
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
@ -2,3 +2,4 @@ tensorflow
|
||||||
keras
|
keras
|
||||||
numpy
|
numpy
|
||||||
pymysql
|
pymysql
|
||||||
|
grpc_tools
|
|
@ -15,12 +15,12 @@ from data import load_gift_data
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
x_train, y_train, tx_train, ty_train = load_gift_data()
|
x_train, y_train, tx_train, ty_train, input_shape = load_gift_data()
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
units = 400
|
units = 400
|
||||||
|
|
||||||
model.add(LSTM(units, activation='relu', input_shape=(4,1) ))
|
model.add(LSTM(units, activation='relu', input_shape=input_shape ))
|
||||||
model.add(Dropout(0.2))
|
model.add(Dropout(0.2))
|
||||||
|
|
||||||
model.add(Dense(1))
|
model.add(Dense(1))
|
||||||
|
@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
model.compile(loss='mse', optimizer='adam')
|
model.compile(loss='mse', optimizer='adam')
|
||||||
|
|
||||||
model.fit(x_train, y_train, batch_size=1, epochs=50)
|
model.fit(x_train, y_train, batch_size=128, epochs=1500)
|
||||||
model.save("./predict_gift")
|
model.save("./predict_gift")
|
||||||
|
|
||||||
p_data = model.predict(tx_train)
|
p_data = model.predict(tx_train)
|
||||||
|
|
13
train_pay.py
13
train_pay.py
|
@ -5,6 +5,7 @@ from keras.layers import Dense, Dropout, Embedding
|
||||||
from keras.layers import InputLayer
|
from keras.layers import InputLayer
|
||||||
from keras.layers import LSTM
|
from keras.layers import LSTM
|
||||||
from keras import backend
|
from keras import backend
|
||||||
|
from keras.layers.recurrent import SimpleRNN
|
||||||
|
|
||||||
import pymysql
|
import pymysql
|
||||||
import pickle
|
import pickle
|
||||||
|
@ -16,18 +17,20 @@ from data import load_pay_data
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
x_train, y_train, tx_train, ty_train = load_pay_data(80)
|
x_train, y_train, tx_train, ty_train, input_shape = load_pay_data(80)
|
||||||
|
|
||||||
|
|
||||||
model = Sequential()
|
model = Sequential()
|
||||||
units = 500
|
units = 500
|
||||||
model.add(LSTM(units, activation='relu', input_shape=(3,1)))
|
model.add(LSTM(units, activation='relu', dropout=0.1, input_shape=input_shape))
|
||||||
model.add(Dropout(0.3))
|
|
||||||
|
# model.add(SimpleRNN(units, activation='relu'))
|
||||||
|
# model.add(Dropout(0.1))
|
||||||
model.add(Dense(1))
|
model.add(Dense(1))
|
||||||
model.summary()
|
model.summary()
|
||||||
model.compile(loss='mse', optimizer='adam')
|
model.compile(loss = 'mse', optimizer = 'adam')
|
||||||
|
|
||||||
model.fit(x_train, y_train, batch_size=1, epochs=50)
|
model.fit(x_train, y_train, batch_size=96, epochs=1200)
|
||||||
model.save("./predict_pay")
|
model.save("./predict_pay")
|
||||||
|
|
||||||
p_data = model.predict(tx_train)
|
p_data = model.predict(tx_train)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user