import { useCreateTicket } from '@axo/shared/data-access/hooks';
import {
  DataAccessContext,
  getExpiryFromJwt,
  isJwtExpired,
  useAPI,
} from '@axo/shared/data-access/provider';
import { useCallback, useContext, useEffect, useReducer, useRef } from 'react';
import { CustomWSStatus, Status, WSReadyState } from './status';
import { IPublishFunction } from './types';

/** State tracked by `websocketReducer` for a single WebSocket connection. */
type WebSocketState = {
  /** Current connection status (a `WSReadyState` or `CustomWSStatus` value). */
  status: Status;
  /** Automatic retries counted so far; reset to 0 by `OPENED` and `RESETTED`. */
  attemptsCount: number;
};

/** Actions accepted by `websocketReducer`; none of them carries a payload. */
type WebSocketAction =
  // Sets the status to `WSReadyState.Connecting`.
  | { type: 'INITIATED' }
  // Sets the status to `WSReadyState.Open` and clears the attempt counter.
  | { type: 'OPENED' }
  // Sets the status to `WSReadyState.Closed`.
  | { type: 'CLOSED' }
  // Sets the status to `CustomWSStatus.Error`.
  | { type: 'ERRORED' }
  // Sets the status to `WSReadyState.Connecting` and increments the attempt counter.
  | { type: 'ATTEMPTED' }
  // Sets the status to `CustomWSStatus.None` and clears the attempt counter.
  | { type: 'RESETTED' };

/** Compared against the attempt counter before an automatic retry is scheduled. */
const MAX_ATTEMPTS = 3;

/** Reducer state before any connection has been attempted. */
const initialState = {
  status: CustomWSStatus.None,
  attemptsCount: 0,
};

/**
 * Pure reducer describing how each connection lifecycle action changes the
 * WebSocket state. The incoming state object is never mutated; every handled
 * action yields a fresh object.
 *
 * @param state - Current status and retry counter.
 * @param action - Lifecycle event to apply.
 * @returns The next state.
 * @throws Error when `action.type` is not one of the known action types.
 */
function websocketReducer(state: WebSocketState, action: WebSocketAction) {
  const { type } = action;

  // A connection attempt has started.
  if (type === 'INITIATED') {
    return { ...state, status: WSReadyState.Connecting };
  }

  // The socket is usable again, so earlier failures no longer count.
  if (type === 'OPENED') {
    return { ...state, status: WSReadyState.Open, attemptsCount: 0 };
  }

  if (type === 'CLOSED') {
    return { ...state, status: WSReadyState.Closed };
  }

  if (type === 'ERRORED') {
    return { ...state, status: CustomWSStatus.Error };
  }

  // An automatic retry was scheduled: count it and report "connecting".
  if (type === 'ATTEMPTED') {
    const nextAttemptsCount = state.attemptsCount + 1;
    return {
      ...state,
      status: WSReadyState.Connecting,
      attemptsCount: nextAttemptsCount,
    };
  }

  // Back to the pristine state, e.g. ahead of a user-triggered retry.
  if (type === 'RESETTED') {
    return {
      ...state,
      status: CustomWSStatus.None,
      attemptsCount: 0,
    };
  }

  throw new Error(`Unhandled action type`);
}

/** Delay before the first automatic retry; doubled for every further attempt. */
const RETRY_BASE_DELAY_MS = 1000;

/**
 * Largest delay `setTimeout` can represent (signed 32-bit milliseconds).
 * Larger values overflow and make the timer fire immediately.
 */
const MAX_TIMER_DELAY_MS = 2147483647;

/** Detaches every handler from `ws` and closes it, so it can no longer affect hook state. */
function releaseSocket(ws: WebSocket) {
  ws.onopen = null;
  ws.onmessage = null;
  ws.onclose = null;
  ws.onerror = null;
  ws.close();
}

/**
 * Opens and maintains an authenticated WebSocket connection.
 *
 * A ticket is requested through `useCreateTicket`, appended to the WS URL as
 * the `ticket` query parameter, and every received frame is JSON-parsed and
 * forwarded to `publish`. Failed connections are retried automatically up to
 * `MAX_ATTEMPTS` times with exponential backoff. The connection is closed when
 * the JWT expires and re-established whenever the token or the WS URL changes.
 *
 * @param publish - Callback receiving each parsed message together with the
 *   socket's ready state at the time of delivery.
 * @returns The current connection `status` and a `manualRetry` function that
 *   restarts the connection when it is closed or errored.
 */
export function useWebSocket(publish: IPublishFunction) {
  const { mutateAsync: createTicket } = useCreateTicket();
  const {
    state: {
      user: { token },
    },
  } = useContext(DataAccessContext);
  const {
    url: { ws: wsURL },
  } = useAPI();

  const [state, dispatch] = useReducer(websocketReducer, initialState);
  const websocket = useRef<WebSocket | null>(null);
  const retryTimeoutRef = useRef<number | NodeJS.Timeout | null>(null);
  const authTimeoutRef = useRef<number | NodeJS.Timeout | null>(null);
  const lastConnectionRef = useRef<{
    token: string | null;
    wsURL?: string | null;
  }>({
    token: null,
    wsURL: null,
  });
  // Synchronous mirror of `state.attemptsCount`. Socket handlers and timers
  // outlive the render that created them, so reading the counter from render
  // state would see a stale value and never reach MAX_ATTEMPTS.
  const attemptsRef = useRef(0);
  // Bumped whenever a connection attempt starts or the hook is torn down, so
  // an attempt still awaiting its ticket can detect that it was superseded.
  const generationRef = useRef(0);
  // Always points at the latest `createWSConnection`; lets the retry timer call
  // it without a circular dependency between the two callbacks.
  const connectRef = useRef<
    ((wsURL: string | undefined) => Promise<void>) | null
  >(null);

  /** Schedules the next automatic reconnect, unless the attempt budget is used up. */
  const attemptRetry = useCallback((wsURL: string | undefined) => {
    if (retryTimeoutRef.current) clearTimeout(retryTimeoutRef.current);
    retryTimeoutRef.current = null;

    const attempt = attemptsRef.current;
    if (attempt >= MAX_ATTEMPTS) return;

    attemptsRef.current = attempt + 1;
    dispatch({ type: 'ATTEMPTED' });
    // Exponential backoff: 1s, 2s, 4s, ...
    retryTimeoutRef.current = setTimeout(() => {
      retryTimeoutRef.current = null;
      void connectRef.current?.(wsURL);
    }, RETRY_BASE_DELAY_MS * 2 ** attempt);
  }, []);

  /** Requests a ticket and opens a socket for it; failures go through `attemptRetry`. */
  const createWSConnection = useCallback(
    async (wsURL: string | undefined) => {
      // A new attempt supersedes any queued retry and any attempt in flight.
      if (retryTimeoutRef.current) clearTimeout(retryTimeoutRef.current);
      retryTimeoutRef.current = null;
      const generation = ++generationRef.current;
      const isSuperseded = () => generation !== generationRef.current;

      dispatch({ type: 'INITIATED' });

      try {
        const ticket = await createTicket();
        // Unmounted or replaced while waiting: do not open a socket nobody owns.
        if (isSuperseded()) return;
        if (!ticket || !wsURL) throw new Error('Missing ticket or WS URL');

        const url = new URL(wsURL);
        url.searchParams.set('ticket', ticket.ID);
        const ws = new WebSocket(url.toString());

        // Never leave an earlier socket behind with live handlers.
        if (websocket.current) releaseSocket(websocket.current);
        websocket.current = ws;

        ws.onopen = () => {
          attemptsRef.current = 0;
          dispatch({ type: 'OPENED' });
        };

        ws.onmessage = (event) => {
          // A malformed frame makes JSON.parse throw; that surfaces as an
          // uncaught error of this handler and leaves the socket untouched.
          const messageData = JSON.parse(event.data);
          publish({
            source: messageData.source,
            code: messageData.code,
            latestMessage: messageData,
            status: ws.readyState,
          });
        };
        ws.onclose = () => {
          if (websocket.current === ws) websocket.current = null;
          dispatch({ type: 'CLOSED' });
        };
        ws.onerror = () => {
          // Detach first: the `close` event that follows an error must not
          // overwrite the status set by the retry logic below.
          releaseSocket(ws);
          if (websocket.current === ws) websocket.current = null;
          dispatch({ type: 'ERRORED' });
          attemptRetry(wsURL);
        };
      } catch {
        if (isSuperseded()) return;
        dispatch({ type: 'ERRORED' });
        attemptRetry(wsURL);
      }
    },
    [createTicket, publish, attemptRetry]
  );

  // Keep the retry timer pointed at the latest connect function.
  useEffect(() => {
    connectRef.current = createWSConnection;
  }, [createWSConnection]);

  /** Closes the current socket, if any, and reports the connection as closed. */
  const closeWSConnection = useCallback(() => {
    const ws = websocket.current;
    if (!ws) return;

    releaseSocket(ws);
    websocket.current = null;

    dispatch({ type: 'CLOSED' });
  }, []);

  /** Arms a timer that closes the connection at the moment `token` expires. */
  const setupAuthTimeoutWSConnection = useCallback(
    (token: string) => {
      if (authTimeoutRef.current) clearTimeout(authTimeoutRef.current);
      authTimeoutRef.current = null;

      if (isJwtExpired(token)) return;

      const expiry = getExpiryFromJwt(token);
      // Without a usable expiry the delay would be NaN, which setTimeout
      // treats as 0 and would close the connection immediately.
      if (expiry == null || Number.isNaN(expiry)) return;

      const timeout = expiry * 1000 - Date.now();
      authTimeoutRef.current = setTimeout(
        closeWSConnection,
        Math.min(Math.max(timeout, 0), MAX_TIMER_DELAY_MS)
      );
    },
    [closeWSConnection]
  );

  /** Restarts the connection on user request; only acts when closed or errored. */
  const manualRetry = () => {
    if ([WSReadyState.Closed, CustomWSStatus.Error].includes(state.status)) {
      attemptsRef.current = 0;
      dispatch({ type: 'RESETTED' });
      void createWSConnection(wsURL);
    }
  };

  // (Re)connect whenever a valid token or the WS URL differs from the ones
  // used for the last connection.
  useEffect(() => {
    if (!token || isJwtExpired(token)) return;

    const { token: lastToken, wsURL: lastWsURL } = lastConnectionRef.current;
    if (token === lastToken && wsURL === lastWsURL) return;

    if (lastToken && lastWsURL) closeWSConnection();

    void createWSConnection(wsURL);
    setupAuthTimeoutWSConnection(token);

    lastConnectionRef.current = { token, wsURL };
  }, [
    token,
    wsURL,
    createWSConnection,
    closeWSConnection,
    setupAuthTimeoutWSConnection,
  ]);

  // Tear everything down on unmount.
  useEffect(() => {
    return () => {
      if (retryTimeoutRef.current) clearTimeout(retryTimeoutRef.current);
      retryTimeoutRef.current = null;
      if (authTimeoutRef.current) clearTimeout(authTimeoutRef.current);
      authTimeoutRef.current = null;
      // Invalidate an attempt that is still waiting for its ticket.
      generationRef.current += 1;
      closeWSConnection();
      // Forget the last connection so a remount (e.g. React StrictMode's
      // double-invoked effects) connects again instead of assuming one exists.
      lastConnectionRef.current = { token: null, wsURL: null };
    };
  }, [closeWSConnection]);

  return { status: state.status, manualRetry };
}
