import copy
import socket
import struct
import sys
import threading
import traceback
import weakref
from typing import Tuple, Optional

from marionette.message_handler import MessageHandler
from marionette.preset_storage import Preset
from marionette.stream import Stream

# Protocol version number handed to MessageHandler.on_connected() on every
# successful connection.
PROTOCOL_VERSION = 5

def _close(wants_to_stop: threading.Event, wake: threading.Event) -> None:
    """Raise the stop flag, then set the wake-up event.

    The stop flag is set first, so anything released by ``wake`` already
    sees ``wants_to_stop`` as set. Calling this more than once is harmless;
    both events simply stay set.
    """
    for event in (wants_to_stop, wake):
        event.set()


class Client:
    """Talks to a server over TCP from a dedicated background thread.

    start() asks the thread to connect to an endpoint. While started, the
    thread reconnects whenever the connection ends. stop() asks it to
    disconnect. close() makes the thread exit, and it also runs automatically
    when the Client is garbage-collected. Connection events and incoming
    messages are delegated to the MessageHandler given to the constructor.

    Call close() explicitly if the Client is still referenced when the
    program exits. Non-daemon threads are joined before atexit hooks (and so
    weakref.finalize) run, so an unclosed thread would keep the process alive.
    """

    class _State:
        """State shared between a Client and its background thread.

        The thread references only this object, never the Client. A Client
        dropped without close() can therefore still be garbage-collected,
        which fires its weakref.finalize hook and stops the thread.
        """

        def __init__(self, retry_delay: float) -> None:
            self.wants_to_stop = threading.Event()  # set by _close()
            self.wake = threading.Event()  # interrupts the thread's waits
            self.lock = threading.Lock()  # guards the fields below
            self.wants_to_connect = False
            self.target_endpoint: Optional[Tuple[str, int]] = None
            self.preset: Optional[Preset] = None
            # Incremented by every start(), so the thread notices a
            # stop() + start() even when the endpoint did not change.
            self.session = 0
            self.retry_delay = retry_delay

    def __init__(
        self, message_handler: MessageHandler, *, retry_delay: float = 1.0
    ) -> None:
        """
        Note: the message_handler will be passed to the background thread,
        and should not be modified outside of this class.

        retry_delay is the number of seconds to wait after a failed
        connection attempt before trying again. start(), stop() and close()
        cut that wait short.
        """
        self._state = self._State(retry_delay)
        # _run is a staticmethod, so self._run is a plain function. The
        # running thread therefore never holds a reference to this Client.
        self._thread = threading.Thread(
            target=self._run, args=(self._state, message_handler)
        )
        self._finalizer = weakref.finalize(
            self, _close, self._state.wants_to_stop, self._state.wake
        )
        self._thread.start()

    @property
    def closed(self) -> bool:
        """True once close() has run (explicitly or via the finalizer)."""
        return not self._finalizer.alive

    def close(self) -> None:
        """Ask the background thread to exit.

        Calling this again does nothing. It does not wait for the thread
        to finish.
        """
        self._finalizer()

    def start(self, ip: str, port: int, preset: Preset):
        """Ask the background thread to connect to (ip, port) using preset.

        Must not be called while already started; call stop() first.
        """
        state = self._state
        with state.lock:
            # Checked under the lock so the check and the update are atomic.
            assert not state.wants_to_connect, 'Already started.'
            state.target_endpoint = (ip, port)
            state.preset = preset
            state.wants_to_connect = True
            state.session += 1
        state.wake.set()

    def stop(self) -> None:
        """Ask the background thread to disconnect and not reconnect."""
        state = self._state
        with state.lock:
            assert state.wants_to_connect, 'Not started.'
            state.wants_to_connect = False
        # Cut short a pending reconnect back-off.
        state.wake.set()

    @staticmethod
    def _run(state: 'Client._State', message_handler: MessageHandler) -> None:
        """Background thread body: (dis)connect as requested until closed."""
        stream = Stream()
        while not state.wants_to_stop.is_set():
            with state.lock:
                wants_to_connect = state.wants_to_connect
                target_endpoint = state.target_endpoint
                session = state.session
            if not wants_to_connect:
                # Sleep until start()/stop()/close() sets wake. The shared
                # state is re-read right after, so clearing cannot lose a
                # request.
                state.wake.wait()
                state.wake.clear()
                continue

            print(f'Trying to connect to {target_endpoint}...')
            assert target_endpoint is not None
            message_handler.on_connection_attempt_started()
            failed = False
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as connection:
                    # Applies to connect() and to later blocking socket calls.
                    connection.settimeout(15)
                    connection.connect(target_endpoint)
                    connected_to_endpoint = target_endpoint
                    stream.socket = connection
                    with state.lock:
                        # Make a copy local to this thread
                        preset = copy.deepcopy(state.preset)
                    print('Connected.')
                    message_handler.on_connected(stream, PROTOCOL_VERSION, preset)
                    while True:
                        with state.lock:
                            # A changed session means stop() + start() ran
                            # since this attempt began (possibly with a new
                            # endpoint or preset), so this connection is stale.
                            still_wanted = (
                                state.wants_to_connect and state.session == session
                            )
                        if state.wants_to_stop.is_set() or not still_wanted:
                            print(f'Disconnecting from {connected_to_endpoint}.')
                            break

                        if not message_handler.read_message(stream):
                            # Quit requested.
                            break
                    connection.shutdown(socket.SHUT_RDWR)
                print(f'Connection to {connected_to_endpoint} closed.')
            except (OSError, struct.error):
                # OSError already covers socket.error, ConnectionError and
                # TimeoutError.
                print('Connection error.')
                print(traceback.format_exc(), file=sys.stderr)
                failed = True
            message_handler.on_disconnected()
            if failed:
                # Back off before retrying so an unreachable server does not
                # cause a tight reconnect loop. start()/stop()/close() set
                # wake to end the wait early.
                state.wake.wait(state.retry_delay)
                state.wake.clear()
        print('Thread exited.')
