39 lines
1.4 KiB
Python
39 lines
1.4 KiB
Python
|
|
|
|
import grpc
|
|
import logging
|
|
from concurrent import futures
|
|
|
|
import predict_pb2, predict_pb2_grpc
|
|
from predict_pb2 import *
|
|
from predict_pb2_grpc import *
|
|
|
|
import numpy
|
|
from keras.models import load_model
|
|
from data import load_pay_data, load_gift_data
|
|
# Load both Keras models once at import time; the Predict servicer's handlers
# reuse these module-level objects. Evaluation order is pay first, then gift.
pay_model, gift_model = load_model("./predict_pay"), load_model("./predict_gift")
|
|
|
|
class Predict(predict_pb2_grpc.PredictServicer):
    """gRPC servicer answering PayDay / GiftDay with Keras model predictions.

    Both RPCs build the same three features from the request (Hour,
    Coin / YesterdayCoin, Coin), run them through their model
    (``pay_model`` / ``gift_model``) and return the model's first output
    truncated to an int.
    """

    @staticmethod
    def _features(request, context):
        """Return [Hour, Coin / YesterdayCoin, Coin] reshaped to (1, 3, 1).

        Aborts the RPC with INVALID_ARGUMENT when YesterdayCoin is zero,
        because the Coin / YesterdayCoin ratio is undefined in that case.
        """
        if not request.YesterdayCoin:
            # Without this guard the division raises ZeroDivisionError and the
            # client only sees an opaque UNKNOWN status. context.abort raises,
            # so execution does not continue past this call.
            context.abort(grpc.StatusCode.INVALID_ARGUMENT,
                          "YesterdayCoin must be non-zero")
        ratio = request.Coin / request.YesterdayCoin
        return numpy.reshape([request.Hour, ratio, request.Coin], (1, 3, 1))

    def _predict(self, model, request, context):
        """Run ``model`` on the request features and return output [0][0] as an int."""
        predict_value = model.predict(self._features(request, context))
        return int(predict_value[0][0])

    def PayDay(self, request, context):
        """Predict the pay value for ``request`` using ``pay_model``."""
        return ReplyPay(Header=ReplyHeader(Code=0, Message=""),
                        Result=self._predict(pay_model, request, context))

    def GiftDay(self, request, context):
        """Predict the gift value for ``request`` using ``gift_model``."""
        return ReplyGift(Header=ReplyHeader(Code=0, Message=""),
                         Result=self._predict(gift_model, request, context))
|
|
|
|
def server(address='0.0.0.0:50051', max_workers=2):
    """Start the Predict gRPC service and block until it is terminated.

    Args:
        address: host:port to listen on, plaintext (no TLS). Defaults to all
            interfaces on port 50051.
        max_workers: number of threads in the pool that executes RPC handlers.

    Raises:
        RuntimeError: if the server could not bind to ``address``.
    """
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    predict_pb2_grpc.add_PredictServicer_to_server(Predict(), grpc_server)
    # add_insecure_port returns the bound port. Older grpcio releases signal a
    # failed bind by returning 0 instead of raising, which would leave the
    # server "running" with no listener, so check explicitly.
    if grpc_server.add_insecure_port(address) == 0:
        raise RuntimeError("failed to bind gRPC server to {}".format(address))
    grpc_server.start()
    grpc_server.wait_for_termination()
|
|
|
|
if "__main__" == __name__:
    # Attach a default stderr handler to the root logger so warnings and
    # errors logged by the process become visible, then serve until stopped.
    logging.basicConfig()
    server()