class WebsocketClientPolicy:
    """Implements the Policy interface by communicating with a server over websocket.

    See WebsocketPolicyServer for a corresponding server implementation.

    The connection is opened lazily: nothing is contacted in ``__init__``; the
    first call to ``act`` or ``reset`` blocks until the server's ``/healthz``
    endpoint answers and the websocket handshake succeeds.
    """

    def __init__(
        self,
        host: str = "0.0.0.0",
        port: Optional[int] = None,
        api_key: Optional[str] = None,
        allow_reconnect: bool = False,
    ) -> None:
        """Store connection settings without contacting the server.

        Args:
            host: Server host name or address (no scheme).
            port: Server port. ``None`` omits the port from the URI. Port 443
                selects ``wss://``; every other value selects ``ws://``.
            api_key: If set, sent as an ``Authorization: Api-Key <key>`` header
                when opening the websocket.
            allow_reconnect: If true, ``act`` reconnects and resends when the
                connection is closed with an error instead of raising.
        """
        self._uri = self._build_uri(host, port)
        self._packer = Packer()
        self._api_key = api_key
        self._ws, self._server_metadata = None, None
        self._allow_reconnect = allow_reconnect

    @staticmethod
    def _build_uri(host: str, port: Optional[int]) -> str:
        """Return the websocket URI for ``host``/``port``."""
        # The None check must come first: int(None) raises TypeError, which
        # previously made the default ``port=None`` unusable.
        if port is None:
            return f"ws://{host}"
        scheme = "wss" if int(port) == 443 else "ws"
        return f"{scheme}://{host}:{port}"

    def _health_check_url(self) -> str:
        """Return the HTTP(S) ``/healthz`` URL matching ``self._uri``."""
        from urllib.parse import urlsplit

        parts = urlsplit(self._uri)
        # wss is only ever selected for port 443, so mapping by scheme matches
        # the port-based choice and also copes with hosts containing colons
        # (e.g. bracketed IPv6 literals).
        scheme = "https" if parts.scheme == "wss" else "http"
        return f"{scheme}://{parts.netloc}/healthz"

    def get_server_metadata(self) -> Optional[Dict]:
        """Return the metadata received from the server on connect.

        Returns ``None`` until a connection has been established by ``act`` or
        ``reset``.
        """
        return self._server_metadata

    def _wait_for_server(self) -> Tuple["websockets.sync.client.ClientConnection", Dict]:
        """Block until the server is healthy and a websocket is established.

        Polls the health endpoint every 5 seconds until it returns a successful
        response, then retries the websocket connection every 5 seconds until
        it succeeds. Both loops retry indefinitely.

        Returns:
            The open connection and the unpacked first message sent by the
            server, which is treated as its metadata.
        """
        health_url = self._health_check_url()

        # First, wait for the health check to pass.
        while True:
            try:
                if requests.get(health_url, timeout=2).ok:
                    logger.info("Health check passed, attempting websocket connection...")
                    break
            except Exception as e:
                # Best-effort probe: any failure just means "not ready yet",
                # but keep the reason available for debugging.
                logger.debug(f"Health check request failed: {e!r}")
            logger.info(f"Health check failed, waiting for server at {health_url}...")
            time.sleep(5)

        # The headers do not change between attempts, so build them once.
        headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None

        # Now attempt the websocket connection.
        while True:
            try:
                conn = websockets.sync.client.connect(
                    self._uri,
                    compression=None,
                    max_size=None,
                    additional_headers=headers,
                    ping_interval=60,
                    ping_timeout=300,
                )
                try:
                    metadata = unpackb(conn.recv())
                except BaseException:
                    # Do not leak a half-initialised connection if the metadata
                    # message cannot be received or decoded.
                    conn.close()
                    raise
                logger.info("Connected to server!")
                return conn, metadata
            except (ConnectionRefusedError, websockets.exceptions.InvalidMessage, EOFError) as e:
                logger.info(f"Websocket connection failed ({e}), retrying...")
                time.sleep(5)

    def _ensure_connected(self) -> None:
        """Open the connection on first use."""
        if self._ws is None:
            self._ws, self._server_metadata = self._wait_for_server()

    def _exchange(self, data: bytes) -> Dict:
        """Send one packed request and return the unpacked reply.

        Raises:
            websockets.exceptions.ConnectionClosedError: The connection dropped
                and ``allow_reconnect`` is false.
            RuntimeError: The server answered with a text frame, which is how
                it reports an error.
        """
        while True:
            try:
                self._ws.send(data)
                response = self._ws.recv()
                break
            except websockets.exceptions.ConnectionClosedError:
                if not self._allow_reconnect:
                    raise
                logger.warning("Connection to server lost, attempting to reconnect...")
                self._ws, self._server_metadata = self._wait_for_server()

        if isinstance(response, str):
            # we're expecting bytes; if the server sends a string, it's an error.
            raise RuntimeError(f"Error in inference server:\n{response}")
        return unpackb(response)

    def act(self, obs: Dict) -> th.Tensor:
        """Send an observation to the server and return the predicted action.

        Args:
            obs: Observation to pack and send.

        Returns:
            The reply's ``"action"`` entry as a float32 tensor.

        Raises:
            KeyError: Two consecutive replies lacked an ``"action"`` entry.
            RuntimeError: The server reported an error as a text frame.
            websockets.exceptions.ConnectionClosedError: The connection dropped
                and ``allow_reconnect`` is false.
        """
        self._ensure_connected()
        data = self._packer.pack(obs)

        action_dict = self._exchange(data)
        try:
            # NOTE(review): the copy presumably detaches the array from the
            # receive buffer so torch gets a writable array -- confirm.
            action_np = deepcopy(action_dict["action"])
        except KeyError:
            # We try getting action one more time before raising error. The
            # retry goes through the same error/reconnect checks as the first
            # attempt.
            logger.warning("No action received from server, retrying one more time...")
            action_dict = self._exchange(data)
            action_np = deepcopy(action_dict["action"])

        return th.from_numpy(action_np).to(th.float32)

    def reset(self) -> None:
        """Send a ``{"reset": True}`` message to the server.

        Connects first if needed. No reply is read here.
        """
        self._ensure_connected()
        self._ws.send(self._packer.pack({"reset": True}))