"""Contains :class:`~.Connection`."""
import abc
import asyncio
import zlib
import io
from . import enums
from . import util
from . import encryption
from .types import VarInt
from .packets import (
GenericPacket,
ServerboundPacket,
ClientboundPacket,
HandshakingPacket,
StatusPacket,
LoginPacket,
PlayPacket,
serverbound,
clientbound,
)
class Connection:
    """A connection between a client and a server.

    Parameters
    ----------
    bound
        Where read packets are bound. Either :mod:`~.serverbound` or :mod:`~.clientbound`.

    Attributes
    ----------
    bound
        Either :class:`~.ServerboundPacket` or :class:`~.ClientboundPacket`.
        Used by :meth:`gen_packet_info` to populate the :attr:`packet_info`
        attribute.
    packet_info : :class:`dict`
        A dictionary whose keys are packet id's and whose values
        are subclasses of :class:`~.Packet`. Populated by :meth:`gen_packet_info`
        and used by :meth:`read_packet` to determine which id corresponds to
        which subclass of :class:`~.Packet`.
    read_lock : :class:`asyncio.Lock`
        The lock to make sure packet reads don't overlap.
    specific_reads : :class:`dict`
        A dictionary whose keys are subclasses of :class:`~.Packet`
        and whose values are :class:`~.AsyncValueHolder`.
        Used in :meth:`dispatch_packet` to send packets to specific reads
        that were made with :meth:`read_packet`.
    reader : :class:`asyncio.StreamReader`
        The reader for receiving packet data.
    writer : :class:`asyncio.StreamWriter`
        The writer for writing packet data.
    comp_threshold : :class:`int`
        The maximum size of a packet before it's compressed.
        If less than or equal to 0, then compression is disabled.
    """

    def __init__(self, bound):
        # Translate the packet module into the matching packet base class;
        # any other value raises KeyError.
        self.bound = {
            serverbound: ServerboundPacket,
            clientbound: ClientboundPacket,
        }[bound]

        # Goes through the property setter below. ``packet_info`` is only
        # generated once both the state and ``ctx`` are available.
        self.current_state = enums.State.Handshaking

        self.read_lock = asyncio.Lock()
        self.specific_reads = {}

        self.reader = None
        self.writer = None

        self.comp_threshold = 0

    def gen_packet_info(self, state, *, ctx=None):
        """Generates the :attr:`packet_info`.

        Parameters
        ----------
        state : :class:`~.State`
            Which state the packet info is for.
        ctx : :class:`~.PacketContext`, optional
            Which context the packet info is for.

        Returns
        -------
        :class:`dict`
            The packet info. See :attr:`packet_info` for a more thorough description.
        """
        state_class = {
            enums.State.Handshaking: HandshakingPacket,
            enums.State.Status: StatusPacket,
            enums.State.Login: LoginPacket,
            enums.State.Play: PlayPacket,
        }[state]

        # Only packets that belong to both the requested state and this
        # connection's direction are candidates.
        candidates = util.get_subclasses(state_class) & util.get_subclasses(self.bound)

        ret = {}
        for pack_class in candidates:
            packet_id = pack_class.get_id(ctx=ctx)
            # Classes that report no id for this context are left out.
            if packet_id is not None:
                ret[packet_id] = pack_class

        return ret

    @property
    def ctx(self):
        """The connection's :class:`~.PacketContext`."""
        return self._ctx

    @ctx.setter
    def ctx(self, value):
        self._ctx = value

        try:
            self.packet_info = self.gen_packet_info(self.current_state, ctx=value)
        except AttributeError:
            # ``current_state`` has not been assigned yet, so the packet info
            # cannot be generated at this point.
            pass

    @property
    def current_state(self):
        """The current :class:`~.State` of the connection."""
        return self._current_state

    @current_state.setter
    def current_state(self, value):
        self._current_state = value

        try:
            self.packet_info = self.gen_packet_info(value, ctx=self.ctx)
        except AttributeError:
            # ``ctx`` has not been assigned yet (the case during ``__init__``),
            # so the packet info cannot be generated at this point.
            pass

    @property
    def comp_enabled(self):
        """Whether compression is enabled."""
        return self.comp_threshold > 0

    @comp_enabled.setter
    def comp_enabled(self, value):
        if not self.comp_enabled and value:
            # We can't know what threshold to set to enable compression
            raise ValueError("Cannot set comp_enabled to True if it is False.")

        if not value:
            self.comp_threshold = 0

    def is_closing(self):
        """Checks if the connection is closed or being closed.

        Returns
        -------
        :class:`bool`
            Whether the connection is closed or being closed.
            ``False`` when no writer has been set.
        """
        return self.writer is not None and self.writer.is_closing()

    def close(self):
        """Closes the connection.

        Should be used alongside the :meth:`wait_closed` method.
        Does nothing when no writer has been set.
        """
        if self.writer is not None:
            self.writer.close()

    async def wait_closed(self):
        """Waits until the connection is closed."""
        if self.writer is not None:
            await self.writer.wait_closed()

    def __del__(self):
        # The finalizer also runs for instances whose ``__init__`` raised
        # before ``writer`` was assigned (e.g. an invalid ``bound``), so the
        # attribute may not exist yet.
        if getattr(self, "writer", None) is not None:
            self.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, exc_tb):
        self.close()
        await self.wait_closed()

    def enable_encryption(self, shared_secret):
        """Enables encryption for the connection.

        Parameters
        ----------
        shared_secret
            The shared secret, either gotten from :func:`~.gen_shared_secret`
            or decrypted from :class:`~.EncryptionResponsePacket`.
        """
        cipher = encryption.gen_cipher(shared_secret)

        # Wrap the existing streams: the reader only gets a decryptor and the
        # writer only gets an encryptor.
        self.reader = encryption.EncryptedStream(self.reader, cipher.decryptor(), None)
        self.writer = encryption.EncryptedStream(self.writer, None, cipher.encryptor())

    def create_packet(self, pack_class, **kwargs):
        """Creates a packet with the connection's :attr:`ctx` attribute.

        Parameters
        ----------
        pack_class : subclass of :class:`~.Packet`
            The packet to create.
        **kwargs
            The attributes of the packet to set and their corresponding values.

        Returns
        -------
        :class:`~.Packet`
            The created packet.
        """
        return pack_class(ctx=self.ctx, **kwargs)

    def dispatch_packet(self, packet):
        """Dispatches a packet to calls of :meth:`read_packet` which specified ``read_class``.

        Every entry of :attr:`specific_reads` whose key ``packet`` is an
        instance of receives the packet and is then removed.

        Parameters
        ----------
        packet : :class:`~.Packet`
            The packet to dispatch.
        """
        # Collect the matches first so the dictionary is never mutated while
        # it is being iterated.
        matched = [
            pack_class
            for pack_class in self.specific_reads
            if isinstance(packet, pack_class)
        ]

        for pack_class in matched:
            self.specific_reads.pop(pack_class).set(packet)

    async def wait_for_incoming_packet(self, pack_class):
        """Waits for an incoming packet read with :meth:`read_packet`.

        Parameters
        ----------
        pack_class : subclass of :class:`~.Packet`
            The packet to wait for.

        Returns
        -------
        :class:`~.Packet` or ``None``
            Returns ``None`` if the connection is closed, else
            returns the packet.
        """
        # Waiters for the same class share a single holder.
        packet_holder = self.specific_reads.get(pack_class)
        if packet_holder is None:
            packet_holder = util.AsyncValueHolder()
            self.specific_reads[pack_class] = packet_holder

        # Wake up once a second so a closed connection is noticed even when
        # the awaited packet never arrives.
        while not self.is_closing():
            try:
                return await asyncio.wait_for(packet_holder.get(), 1)
            except asyncio.TimeoutError:
                pass

        return None

    async def decompress_packet_data(self, data):
        """Decompresses raw packet data.

        Parameters
        ----------
        data
            The raw packet data.

        Returns
        -------
        :class:`io.BytesIO`
            The raw decompressed packet data.

        Raises
        ------
        :exc:`ValueError`
            If compression is enabled and the data length of the packet
            is greater than 0 but less than the :attr:`comp_threshold`
            attribute. The connection is closed before raising.
        """
        data = util.file_object(data)

        if self.comp_enabled:
            data_len = VarInt.unpack(data, ctx=self.ctx)

            # A data length of 0 marks a packet that was sent uncompressed.
            if data_len > 0:
                # Packets whose size is exactly the threshold are legitimately
                # compressed by the peer; only sizes below it are invalid.
                if data_len < self.comp_threshold:
                    self.close()
                    await self.wait_closed()

                    raise ValueError(f"Invalid data length {data_len} for compression threshold {self.comp_threshold}")

                data = io.BytesIO(zlib.decompress(data.read(), bufsize=data_len))

        return data

    async def read_packet(self, read_class=None):
        """Reads a packet.

        Parameters
        ----------
        read_class : subclass of :class:`~.Packet`, optional
            The packet you want to read. If unspecified, whatever
            the next packet is will be returned. Requires this method
            to be called elsewhere with ``read_class`` unspecified to work.

        Returns
        -------
        :class:`~.Packet` or ``None``
            If EOF is reached when reading the packet, then the connection
            will be closed and ``None`` will be returned. Otherwise the read
            packet will be returned.
        """
        if read_class is not None:
            return await self.wait_for_incoming_packet(read_class)

        data = b""

        try:
            async with self.read_lock:
                length_buf = b""
                length = -1

                # The length prefix is variable-sized, so read one byte at a
                # time until the accumulated bytes unpack successfully.
                # NOTE(review): there is no upper bound on the prefix size; a
                # peer that never completes it keeps this loop reading until EOF.
                while True:
                    length_buf += await self.reader.readexactly(1)

                    try:
                        length = VarInt.unpack(length_buf, ctx=self.ctx)
                    except Exception:
                        # Presumably an incomplete prefix; keep reading.
                        # ``Exception`` rather than a bare ``except`` so that
                        # KeyboardInterrupt/SystemExit are not swallowed.
                        continue

                    break

                if length >= 0:
                    data = await self.reader.readexactly(length)
        except asyncio.IncompleteReadError:
            # EOF in the middle of a packet: shut the connection down.
            self.close()
            await self.wait_closed()

            return None

        data = await self.decompress_packet_data(data)

        packet_id = VarInt.unpack(data, ctx=self.ctx)

        pack_class = self.packet_info.get(packet_id)
        if pack_class is None:
            # Unknown id: fall back to a generic packet for that id.
            pack_class = GenericPacket(packet_id)

        packet = pack_class.unpack(data, ctx=self.ctx)

        # Hand the packet to anyone waiting on this specific class as well.
        self.dispatch_packet(packet)

        return packet

    def compress_packet_data(self, data):
        """Compresses raw packet data.

        When compression is enabled the data is prefixed with its
        uncompressed length, or with 0 if it was left uncompressed.

        Parameters
        ----------
        data : :class:`bytes`
            The raw, uncompressed packet data.

        Returns
        -------
        :class:`bytes`
            The raw, potentially compressed packet data.
        """
        if self.comp_enabled:
            # A length of 0 tells the receiver the payload is not compressed.
            data_len = 0

            if len(data) > self.comp_threshold:
                data_len = len(data)
                data = zlib.compress(data)

            data = VarInt.pack(data_len, ctx=self.ctx) + data

        return data

    async def write_packet(self, packet, **kwargs):
        """Writes a packet.

        Parameters
        ----------
        packet : subclass of :class:`~.Packet` or :class:`~.Packet`
            If a subclass of :class:`~.Packet`, then the packet to write
            will be created by forwarding ``packet`` and ``**kwargs`` to the
            :meth:`create_packet` method. Otherwise, ``packet`` is the
            packet to write.

            ``packet`` being a subclass of :class:`~.Packet` is preferred
            so that the packet is created with the correct context for
            the connection.
        **kwargs
            The packet attributes to set and their corresponding values.
            Only able to be passed if ``packet`` is a subclass of :class:`~.Packet`.

        Returns
        -------
        :class:`~.Packet`
            The written packet.

        Raises
        ------
        :exc:`TypeError`
            If ``**kwargs`` is passed but ``packet`` isn't a subclass of
            :class:`~.Packet`.
        """
        if isinstance(packet, type):
            packet = self.create_packet(packet, **kwargs)
        elif kwargs:
            raise TypeError("Packet object passed with keyword arguments")

        data = packet.pack(ctx=self.ctx)
        data = self.compress_packet_data(data)

        # Every packet on the wire is prefixed with its total length.
        data = VarInt.pack(len(data), ctx=self.ctx) + data

        self.writer.write(data)
        await self.writer.drain()

        return packet