Source code for nanome.beta.nanome_sdk.plugin_server

import asyncio
import json
import logging
import logging.config
import os
import ssl
import sys

from nanome._internal.network.packet import Packet, PacketTypes
from nanome._internal.serializer_fields import TypeSerializer
from nanome.api.serializers import CommandMessageSerializer
from nanome.beta.nanome_sdk.logs import configure_main_process_logging
from nanome.beta.nanome_sdk.utils import convert_bytes_to_packet
from nanome.beta.nanome_sdk.session import run_session_loop_py
from nanome.util.config import str2bool

# Public API of this module.
__all__ = ["PluginServer"]


# Module-level logger used by the main (server) process.
logger = logging.getLogger(__name__)


# Delay, in seconds, between keep-alive packets (passed to asyncio.sleep in
# PluginServer.keep_alive).
KEEP_ALIVE_TIME_INTERVAL = 60.0
# When true, log packets received from session processes are also written to
# the NTS connection. Read once from the PLUGIN_REMOTE_LOGGING environment
# variable at import time.
# NOTE(review): the default passed to str2bool is the bool False, not a
# string -- confirm str2bool accepts bool input.
PLUGIN_REMOTE_LOGGING = str2bool(os.environ.get('PLUGIN_REMOTE_LOGGING', False))


class PluginServer:
    """Bridge between NTS and one subprocess per client session.

    The server opens a single TLS connection to NTS, registers the plugin,
    and then routes every incoming packet to the subprocess that owns the
    packet's session id, starting that subprocess on first contact. Packets
    written by session subprocesses to their stdout are forwarded back to
    NTS, except log packets, which are replayed through the local
    ``sessions`` logger.
    """

    def __init__(self):
        self.plugin_id = None
        self._sessions = {}  # session_id -> asyncio subprocess
        self.plugin_class = None
        self.polling_tasks = {}  # session_id -> task running poll_session
        # Connection state is created in run(). It is initialised here so
        # that cleanup stays safe when connecting fails part-way through.
        self.plugin_name = None
        self.nts_reader = None
        self.nts_writer = None
        self.keep_alive_task = None
        self.poll_nts_task = None

    async def run(self, nts_host, nts_port, plugin_name, description, plugin_class):
        """Connect to NTS, register the plugin and serve until the connection ends.

        :param nts_host: host name of the NTS instance.
        :param nts_port: port of the NTS instance.
        :param plugin_name: plugin name; overridden by the ``PLUGIN_NAME``
            environment variable when that is set and non-empty.
        :param description: plugin description sent during registration.
        :param plugin_class: plugin class; the file of its module is handed
            to every session subprocess.

        Any exception raised while connecting or polling is logged rather
        than propagated. Background tasks and the connection are always
        released on exit.
        """
        self.plugin_class = plugin_class
        self.plugin_name = os.environ.get("PLUGIN_NAME") or plugin_name
        try:
            # NOTE(review): a bare SSLContext(PROTOCOL_TLS) does not verify
            # the server certificate by default, and PROTOCOL_TLS is
            # deprecated since Python 3.10. Left unchanged to avoid breaking
            # existing deployments -- confirm whether verification is wanted.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
            self.nts_reader, self.nts_writer = await asyncio.open_connection(
                nts_host, nts_port, ssl=ssl_context)
            self.plugin_id = await self.connect_plugin(self.plugin_name, description)
            configure_main_process_logging(self.nts_writer, self.plugin_id, self.plugin_name)
            logger.info(f"Plugin Connected. ID: {self.plugin_id}")
            self.keep_alive_task = asyncio.create_task(self.keep_alive(self.plugin_id))
            self.poll_nts_task = asyncio.create_task(self.poll_nts())
            await self.poll_nts_task
        except Exception as e:
            logger.error(e, exc_info=True)
        finally:
            self._release_connection()

    def _release_connection(self):
        """Cancel background tasks and close the NTS writer, skipping whatever was never created."""
        for task in (self.keep_alive_task, self.poll_nts_task):
            if task is not None:
                task.cancel()
        if self.nts_writer is not None:
            self.nts_writer.close()

    @staticmethod
    async def _read_packet_bytes(reader):
        """Read one whole packet (header followed by payload) and return its raw bytes.

        Raises ``asyncio.IncompleteReadError`` if the stream ends before the
        packet is complete.
        """
        header = await reader.readexactly(Packet.packet_header_length)
        payload_length = Packet.header_unpack(header)[4]
        return header + await reader.readexactly(payload_length)

    @staticmethod
    def _kill_process(process):
        """Kill a session subprocess, tolerating one that has already exited."""
        try:
            process.kill()
        except ProcessLookupError:
            # The process is already gone, which is the state we wanted.
            pass

    @staticmethod
    def _resolve_log_level(level_name):
        """Map a logging level name to its numeric level, defaulting to INFO.

        ``logging.getLevelName`` returns an int only for a known level name;
        for anything else (including ``None``) it returns a string.
        """
        level = logging.getLevelName(level_name)
        return level if isinstance(level, int) else logging.INFO

    async def poll_nts(self):
        """Poll NTS for packets, and forward them to the plugin server."""
        while True:
            received_bytes = await self._read_packet_bytes(self.nts_reader)
            await self.route_bytes(received_bytes)

    async def connect_plugin(self, name, description):
        """Send a packet to NTS to register plugin.

        :param name: plugin name to register.
        :param description: human readable plugin description.
        :return: the plugin id taken from the header of the NTS response.
        """
        plugin_info = {
            'name': name,
            'description': description,
            'category': "",
            'tags': [],
            'hasAdvanced': False,
            'auth': os.environ.get("NTS_KEY"),
            'permissions': [],
            'integrations': [],
        }
        packet = Packet()
        # Session id and plugin id are both 0 in the registration packet.
        packet.set(0, Packet.packet_type_plugin_connection, 0)
        packet.write_string(json.dumps(plugin_info))
        self.nts_writer.write(packet.pack())
        await self.nts_writer.drain()
        # Wait for response containing plugin_id
        header = await self.nts_reader.readexactly(Packet.packet_header_length)
        return Packet.header_unpack(header)[3]

    async def keep_alive(self, plugin_id):
        """Long running task to send keep alive packets to NTS.

        :param plugin_id: id written into every keep-alive packet.
        """
        while True:
            logger.debug("Sending keep alive packet.")
            packet = Packet()
            packet.payload_length = 0
            packet.session_id = 0
            packet.plugin_id = plugin_id
            packet.packet_type = PacketTypes.keep_alive
            self.nts_writer.write(packet.pack())
            await self.nts_writer.drain()
            await asyncio.sleep(KEEP_ALIVE_TIME_INTERVAL)

    async def route_bytes(self, received_bytes):
        """Route bytes from NTS to the appropriate session.

        :param received_bytes: raw bytes of one complete packet.
        """
        packet = convert_bytes_to_packet(received_bytes)
        session_id = packet.session_id
        packet_type = packet.packet_type
        if packet_type == PacketTypes.message_to_plugin:
            await self._forward_to_session(packet, received_bytes)
        elif packet_type == PacketTypes.client_disconnection:
            self._disconnect_session(session_id)
        elif packet_type == PacketTypes.keep_alive:
            logger.debug(f"Keep Alive Packet received. Plugin id: {packet.plugin_id}")
        elif packet_type == PacketTypes.plugin_list:
            logger.info("Plugin list happening?")

    async def _forward_to_session(self, packet, received_bytes):
        """Start the session subprocess on first contact, else pipe the packet to its stdin."""
        session_id = packet.session_id
        process = self._sessions.get(session_id)
        if process is None:
            # The first message of a session carries the version table.
            serializer = CommandMessageSerializer()
            received_version_table, _, _ = serializer.deserialize_command(packet.payload, None)
            version_table = TypeSerializer.get_best_version_table(received_version_table)
            await self.start_session_process(version_table, packet, self.plugin_class)
            return
        try:
            process.stdin.write(received_bytes)
            await process.stdin.drain()
        except (BrokenPipeError, ConnectionResetError):
            # A dead session must not take the whole server down. The entry
            # is kept so later packets are not mistaken for a new session.
            logger.warning(f"Session {session_id} process is no longer accepting input.")

    def _disconnect_session(self, session_id):
        """Kill the subprocess of a disconnected session and drop its bookkeeping."""
        logger.info(f"Disconnecting Session {session_id}.")
        process = self._sessions.pop(session_id, None)
        if process is None:
            logger.warning(f"Session {session_id} process was already disconnected.")
            return
        self._kill_process(process)
        # Without this the polling_tasks dict would grow for every session served.
        polling_task = self.polling_tasks.pop(session_id, None)
        if polling_task is not None:
            polling_task.cancel()

    async def start_session_process(self, version_table, packet, plugin_class):
        """Spawn the subprocess for a new session and relay its connect packet to NTS.

        :param version_table: version table exported to the subprocess as
            JSON in the ``NANOME_VERSION_TABLE`` environment variable.
        :param packet: packet that opened the session; supplies the plugin
            and session ids.
        :param plugin_class: plugin class whose module file path is passed
            to the subprocess.

        If the subprocess fails to produce a complete connect packet it is
        killed and the session is not registered.
        """
        plugin_id = packet.plugin_id
        session_id = packet.session_id
        logger.info(f"Starting process for Session {session_id}")
        env = {
            **os.environ,
            'NANOME_VERSION_TABLE': json.dumps(version_table),
        }
        plugin_class_filepath = os.path.abspath(sys.modules[plugin_class.__module__].__file__)
        session_process = await asyncio.create_subprocess_exec(
            sys.executable, run_session_loop_py,
            str(plugin_id), str(session_id), self.plugin_name, plugin_class_filepath,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            cwd=os.getcwd(),
            env=env)
        try:
            connect_data = await self._read_packet_bytes(session_process.stdout)
        except asyncio.IncompleteReadError:
            logger.error(f"Session {session_id} process exited before sending its connect packet.")
            self._kill_process(session_process)
            return
        except Exception:
            logger.error("Failed to unpack header", exc_info=True)
            # Do not leave an unusable subprocess running in the background.
            self._kill_process(session_process)
            return
        self.nts_writer.write(connect_data)
        self._sessions[session_id] = session_process
        self.polling_tasks[session_id] = asyncio.create_task(self.poll_session(session_process))

    async def poll_session(self, process):
        """Poll a session process for packets, and forward them to NTS.

        Returns once the stdout of the process ends.
        """
        while True:
            try:
                outgoing_bytes = await self._read_packet_bytes(process.stdout)
            except asyncio.IncompleteReadError:
                logger.debug("Incomplete read error. Ending session polling task.")
                break
            # Pull out session logs, otherwise forward bytes to NTS
            packet = convert_bytes_to_packet(outgoing_bytes)
            if packet.packet_type == PacketTypes.live_logs:
                asyncio.create_task(self.handle_session_log_packet(packet))
            else:
                self.nts_writer.write(outgoing_bytes)

    async def handle_session_log_packet(self, packet):
        """If the packet is a log message, use loggers in current process to handle it.

        This provides a way of rendering logs from session processes,
        without sending them directly to stdout. When PLUGIN_REMOTE_LOGGING
        is enabled the packet is also written to NTS.
        """
        if packet.packet_type != PacketTypes.live_logs:
            return
        try:
            gelf_dict = json.loads(packet.payload.decode("utf-8"))
        except ValueError:
            # Covers both invalid UTF-8 and invalid JSON.
            logger.warning("Discarding malformed session log packet.")
            return
        # Create a log record using values from gelf dict
        logrecord = logging.LogRecord(
            name=gelf_dict.get("host"),  # or map to some other field if appropriate
            level=self._resolve_log_level(gelf_dict.get("level_name")),
            pathname=gelf_dict.get("file"),
            lineno=gelf_dict.get("line"),
            msg=gelf_dict.get("short_message"),
            args=(),  # not provided in GELF
            exc_info=None,  # not provided in GELF
        )
        logrecord.processName = gelf_dict.get("_process_name")
        logging.getLogger("sessions").handle(logrecord)
        if PLUGIN_REMOTE_LOGGING:
            self.nts_writer.write(packet.pack())