import socket
import struct
from typing import cast

import select


class MessageTooLargeError(ConnectionError):
    """Raised when a message would need more buffer space than the stream allows."""


class Stream:
    def __init__(
        self,
        s: socket.socket | None = None,
        initial_capacity: int = 1024,
        max_capacity: int = 10 * 1024 * 1024,
    ) -> None:
        assert initial_capacity <= max_capacity
        self.socket = s
        self._max_buffer_capacity = max_capacity
        self._buffer = bytearray(initial_capacity)

    def send_str(self, msg: str) -> None:
        encoded = msg.encode("utf-8")
        self.send_int(len(encoded))

        self._resize_buffer(len(encoded))
        self._buffer[: len(encoded)] = encoded
        self.send(len(encoded))

    def send_int(self, msg: int) -> None:
        self._resize_buffer(4)
        struct.pack_into("<i", self._buffer, 0, msg)
        self.send(4)

    def send_float(self, msg: float) -> None:
        self._resize_buffer(4)
        struct.pack_into("f", self._buffer, 0, msg)
        self.send(4)

    def send_double(self, msg: float) -> None:
        self._resize_buffer(8)
        struct.pack_into("d", self._buffer, 0, msg)
        self.send(8)

    def send_bool(self, msg: bool) -> None:
        self._resize_buffer(1)
        self._buffer[0] = 1 if msg else 0
        self.send(1)

    def send(self, length: int) -> None:
        assert self.socket is not None

        total = 0
        while total < length:
            if not self.can_write():
                raise TimeoutError()
            sent = self.socket.send(self._buffer[total:length])
            if sent == 0:
                raise ConnectionResetError()
            total += sent

    def receive_str(self) -> str:
        length = self.receive_int()
        encoded = self.receive_bytes(length)
        return str(encoded, encoding="utf-8")

    def receive_int(self) -> int:
        return cast(int, struct.unpack("<i", self.receive_bytes(4))[0])

    def receive_float(self) -> float:
        return cast(float, struct.unpack("f", self.receive_bytes(4))[0])

    def receive_double(self) -> float:
        return cast(float, struct.unpack("d", self.receive_bytes(8))[0])

    def receive_bool(self) -> bool:
        received = self.receive_bytes(1)
        if received == bytes([0]):
            return False
        if received == bytes([1]):
            return True
        raise struct.error(f"Invalid bool {received}.")

    def receive_byte(self) -> memoryview:
        return self.receive_bytes(1)

    def receive_bytes(self, length: int) -> memoryview:
        assert self.socket is not None

        self._resize_buffer(length)

        total = 0
        while True:
            if total >= length:
                return memoryview(self._buffer)[:length]
            if not self.can_read():
                raise TimeoutError()
            bytes_received = self.socket.recv_into(self._buffer, length - total)
            if bytes_received == 0:
                raise ConnectionResetError()
            total += bytes_received

    def can_read(self, timeout: int = 15) -> bool:
        can_read, _can_write, _has_exception = select.select(
            [self.socket], [], [], timeout
        )
        return self.socket in can_read

    def can_write(self, timeout: int = 15) -> bool:
        _can_read, can_write, _has_exception = select.select(
            [], [self.socket], [], timeout
        )
        return self.socket in can_write

    def _resize_buffer(self, min_length: int) -> None:
        while len(self._buffer) < min_length:
            new_size = min(len(self._buffer) * 2, self._max_buffer_capacity)
            if new_size < min_length and new_size == self._max_buffer_capacity:
                raise MessageTooLargeError()
            self._buffer.extend(bytearray(new_size - len(self._buffer)))
