163 lines
4.8 KiB
Python
163 lines
4.8 KiB
Python
import functools
import logging
import time
from typing import Any, Callable

import msgpack
import msgpack_numpy as m
import zmq
import zstandard as zstd
|
|
|
|
|
|
logger = logging.getLogger(name=__name__)

# Module-wide zstd codec objects shared by _pack() / _unpack().
# NOTE(review): level 12 presumably trades extra CPU per message for smaller
# payloads; revisit if per-request latency matters.
_ZSTD_LEVEL = 12
compresser = zstd.ZstdCompressor(level=_ZSTD_LEVEL)
decompresser = zstd.ZstdDecompressor()
|
|
|
|
|
|
def _pack(data: Any) -> bytes:
    """Serialize *data* to msgpack (numpy-aware via ``m.encode``) and zstd-compress it.

    Inverse of :func:`_unpack`.
    """
    serialized = msgpack.packb(data, default=m.encode, use_bin_type=True)
    return compresser.compress(serialized)
|
|
|
|
|
|
def _unpack(data: bytes) -> Any:
    """Zstd-decompress *data* and decode the msgpack payload (inverse of :func:`_pack`).

    Strings come back as ``str`` (``raw=False``), and numpy arrays packed by
    ``m.encode`` are rebuilt by the ``m.decode`` object hook.
    """
    # NOTE(review): msgpack_numpy's decode hook may fall back to pickle for
    # object-dtype arrays. If so, this must not be fed untrusted network data.
    # Confirm against the installed msgpack_numpy version.
    plain = decompresser.decompress(data)
    return msgpack.unpackb(plain, raw=False, object_hook=m.decode)
|
|
|
|
|
|
class Server:
    """ZeroMQ REP server that dispatches packed requests to registered handlers.

    Both directions are encoded with ``_pack`` / ``_unpack``:

    * request:  ``{"command": str, "data": Any}``
    * response: ``{"status": "ok", "data": <handler result>}`` or
      ``{"status": "error", "data": <error message str>}``
    """

    def __init__(self, host: str = "*", port: int = 5555):
        """Create a ZeroMQ context and bind a REP socket on ``tcp://host:port``.

        If ``bind`` raises, the socket and context are released before the
        exception is re-raised.
        """
        self.host = host
        self.port = port

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        try:
            self.socket.bind(f"tcp://{self.host}:{self.port}")
        except Exception:
            # Otherwise a failed bind leaks both the socket and the context.
            self.close()
            raise
        logger.info("Server started at tcp://%s:%s", self.host, self.port)

        self.endpoints: dict[str, Callable[[Any], Any]] = {}

    def register_endpoint(self, command: str, func: Callable[[Any], Any]):
        """Register *func* as the handler for *command* (replacing any previous one).

        The handler is called with no arguments when a request's ``data`` is
        ``None``, otherwise with ``data`` as its single positional argument.
        """
        self.endpoints[command] = func
        logger.info("Registered endpoint: %s -> %s", command, func)

    def return_error(self, message: str) -> None:
        """Reply to the pending request with an error status and *message*."""
        self.socket.send(_pack({"status": "error", "data": message}))

    def return_ok(self, data: Any) -> None:
        """Reply to the pending request with an ok status and *data*."""
        self.socket.send(_pack({"status": "ok", "data": data}))

    def close(self) -> None:
        """Close the socket and terminate the ZeroMQ context."""
        self.socket.close()
        self.context.term()

    def handle_once(self) -> None:
        """Receive one request, dispatch it, and send exactly one reply.

        A REP socket must answer each request before it can receive the next
        one. Every path below, including malformed input, therefore replies
        instead of raising, so a bad client cannot take the server down.
        """
        raw = self.socket.recv()

        try:
            message = _unpack(raw)
        except Exception as e:
            logger.warning("Failed to decode request: %s", e)
            self.return_error(f"Malformed request: {e}")
            return

        if not isinstance(message, dict):
            logger.warning(
                "Malformed request: expected a map, got %s", type(message).__name__
            )
            self.return_error(
                "Malformed request: expected a map with 'command' and 'data'"
            )
            return

        cmd = message.get("command")
        data = message.get("data")

        logger.info("Received Command: %s", cmd)

        try:
            handler = self.endpoints.get(cmd)
        except TypeError:
            # Unhashable command (e.g. a list) cannot match any endpoint.
            handler = None

        if handler is None:
            logger.warning("Unknown command: %s", cmd)
            self.return_error(f"Unknown command: {cmd}")
            return

        try:
            response = handler() if data is None else handler(data)
            # Kept inside the try: if the result cannot be packed, nothing has
            # been sent yet, so the error reply below is still legal.
            self.return_ok(response)
        except Exception as e:
            logger.exception("Error handling command %s: %s", cmd, e)
            self.return_error(str(e))

    def loop_forever(self):
        """Serve requests until interrupted with Ctrl-C, then close the server.

        Any other exception propagates to the caller. The socket and context
        are released in either case.
        """
        try:
            while True:
                self.handle_once()
        except KeyboardInterrupt:
            logger.info("Server shutting down...")
        finally:
            self.close()
|
|
|
|
|
|
class Client:
    """ZeroMQ REQ client for the ``Server`` request/reply protocol.

    Call :meth:`close` when done, or use the client as a context manager, to
    release the socket and the ZeroMQ context.
    """

    def __init__(self, host: str = "localhost", port: int = 5555):
        """Create a ZeroMQ context and connect a REQ socket to ``tcp://host:port``."""
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        try:
            self.socket.connect(f"tcp://{host}:{port}")
        except Exception:
            # Otherwise a failed connect leaks both the socket and the context.
            self.close()
            raise
        logger.info("Client connected to tcp://%s:%s", host, port)

    def call_endpoint(self, command: str, data=None):
        """Send *command* with *data* and block until the server replies.

        Returns:
            The reply's ``data`` field when its ``status`` is ``"ok"``.

        Raises:
            Exception: when the server replies with any other status. The
                message includes the server-supplied error text.
        """
        self.socket.send(_pack({"command": command, "data": data}))
        message = _unpack(self.socket.recv())

        if message.get("status") == "ok":
            return message.get("data")

        error = f"Error from server: {message.get('data')}"
        logger.error("%s", error)
        raise Exception(error)

    def close(self) -> None:
        """Close the socket and terminate the ZeroMQ context."""
        self.socket.close()
        self.context.term()

    def __enter__(self) -> "Client":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
|
|
|
|
|
|
def freq_control(freq: float = 25):
    """Decorator factory that paces a function to at most *freq* calls per second.

    After the wrapped function returns, the wrapper sleeps for whatever
    remains of the ``1 / freq`` second period. A loop calling the wrapped
    function therefore runs at roughly *freq* Hz. Calls that already take
    longer than one period are not slept at all. If the wrapped function
    raises, the exception propagates immediately without sleeping.

    Args:
        freq: Target call frequency in Hz; must be positive.

    Returns:
        A decorator that preserves the wrapped function's name and docstring.

    Raises:
        ValueError: if *freq* is not positive. This is checked when the
            decorator is created, not on every call.
    """
    if freq <= 0:
        raise ValueError(f"freq must be positive, got {freq!r}")
    period = 1.0 / freq

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so wall-clock adjustments cannot
            # produce a negative or inflated elapsed time.
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start
            time.sleep(max(0.0, period - elapsed))
            return result

        return wrapper

    return decorator
|
|
|
|
|
|
if __name__ == "__main__":
    import sys
    from time import sleep

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # Validate explicitly instead of with ``assert``. Asserts are stripped
    # under ``python -O``, which would skip the check and leave ``mode`` unbound.
    if len(sys.argv) != 2 or sys.argv[1] not in ("server", "client"):
        sys.exit("Usage: python service.py [server|client]")
    mode = sys.argv[1]

    ## Protocol:
    # Request: { "command": str, "data": Any }
    # Response: { "status": "ok" | "error", "data": Any if status=="ok" else str (ErrorMsg) }

    if mode == "server":
        server = Server()
        server.register_endpoint("ping", lambda: "pong")
        server.register_endpoint("echo", lambda x: x)
        server.register_endpoint("add", lambda data: data["a"] + data["b"])
        server.loop_forever()

    elif mode == "client":
        client = Client()
        while True:
            try:
                response = client.call_endpoint("ping")
                print(f"Response from server: {response}")
                response = client.call_endpoint("echo", "Hello, World!")
                print(f"Response from server: {response}")
                response = client.call_endpoint("add", {"a": 5, "b": 10})
                print(f"Response from server: {response}")

                sleep(0.2)

            except Exception as e:
                print(f"Error: {e}")
                break
|