package com.finconsgroup.itserr.marketplace.notification.bs.websocket;

import com.finconsgroup.itserr.marketplace.core.web.exception.WP2AuthenticationException;
import com.finconsgroup.itserr.marketplace.core.web.security.jwt.JwtTokenHolder;
import com.finconsgroup.itserr.marketplace.core.web.security.jwt.JwtTokenVerifier;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * <p>
 * This component implements the {@link HandshakeInterceptor} interface. It intercepts WebSocket handshake requests to read
 * and validate custom subprotocol information, including JWTs (JSON Web Tokens), from the {@code Sec-WebSocket-Protocol}
 * request header.
 * </p>
 * <p>
 * The expected format of the {@code Sec-WebSocket-Protocol} header is a comma-separated list of alternating keys and
 * values: {@code key, value, key2, value2}. Keys are matched case-insensitively; a trailing key without a value is
 * ignored. If the header is sent more than once, all occurrences are treated as one combined list
 * (RFC 6455, section 11.3.4).
 * </p>
 * <p>
 * According to the received subprotocols and their values, this class does the following:
 * </p>
 * <ul>
 * <li><b>jwt, &lt;token&gt;</b>: the token is passed to {@code JwtTokenHolder.setToken(token, jwtTokenVerifier)}.
 * If that throws {@link WP2AuthenticationException}, or no user id is available afterwards, the handshake is aborted
 * with {@code 401 Unauthorized}. Otherwise {@code jwt} is echoed back in the response {@code Sec-WebSocket-Protocol}
 * header.</li>
 * </ul>
 * <p>
 * When no header or no (non-blank) {@code jwt} value is present, the handshake proceeds untouched.
 * </p>
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class WsProtocolJwtInterceptor implements HandshakeInterceptor {

    /** Handshake header carrying the offered subprotocols (request) and the selected one (response). */
    private static final String PROTOCOL_HEADER = "Sec-WebSocket-Protocol";

    /** Subprotocol key whose value is the JWT token. */
    private static final String JWT_KEY = "jwt";

    private final JwtTokenVerifier jwtTokenVerifier;

    /**
     * Reads the subprotocol attributes and, if a {@code jwt} value is present, authenticates it.
     *
     * @param request    the handshake request
     * @param response   the handshake response; receives {@code 401} on rejection or the echoed {@code jwt} protocol
     * @param wsHandler  the target WebSocket handler (unused)
     * @param attributes the WebSocket session attributes (unused)
     * @return {@code true} to proceed with the handshake, {@code false} if the supplied JWT was rejected
     * @throws Exception never thrown directly; declared by the {@link HandshakeInterceptor} contract
     */
    @Override
    public boolean beforeHandshake(
            @NonNull final ServerHttpRequest request,
            @NonNull final ServerHttpResponse response,
            @NonNull final WebSocketHandler wsHandler,
            @NonNull final Map<String, Object> attributes) throws Exception {

        // RFC 6455 allows the header to be repeated; all occurrences form a single logical list.
        final List<String> headerValues = request.getHeaders().get(PROTOCOL_HEADER);
        if (headerValues == null || headerValues.isEmpty()) {
            return true;
        }
        final Map<String, String> protocolAttributes = parseProtocolAttributes(String.join(",", headerValues));

        // Handle JWT protocol: absent or blank token means the handshake proceeds unauthenticated.
        final String jwt = protocolAttributes.get(JWT_KEY);
        if (StringUtils.isBlank(jwt)) {
            return true;
        }

        try {
            JwtTokenHolder.setToken(jwt, jwtTokenVerifier);
        } catch (WP2AuthenticationException e) {
            log.error("Error validating JWT token", e);
            return reject(response);
        }
        if (JwtTokenHolder.getUserId().isEmpty()) {
            log.error("Authorization was unsuccessful");
            return reject(response);
        }

        // Echo the subprotocol so the client sees "jwt" as the negotiated protocol.
        response.getHeaders().add(PROTOCOL_HEADER, JWT_KEY);
        return true;
    }

    @Override
    public void afterHandshake(
            @NonNull final ServerHttpRequest request,
            @NonNull final ServerHttpResponse response,
            @NonNull final WebSocketHandler wsHandler,
            final Exception exception) {
        // Do nothing
    }

    /**
     * Parses a comma-separated {@code key, value, key2, value2} list into a map.
     * Keys are trimmed and lower-cased with {@link Locale#ROOT} (locale-independent); values are trimmed.
     * Pairs with a blank key are skipped, a trailing key without value is ignored, and for duplicate keys the last
     * value wins.
     *
     * @param protocol the raw header value, not {@code null}
     * @return a mutable map of the parsed attributes, never {@code null}
     */
    private static Map<String, String> parseProtocolAttributes(final String protocol) {
        final Map<String, String> protocolAttributes = new HashMap<>();
        final String[] tokens = StringUtils.split(protocol, ",");
        for (int i = 1; i < tokens.length; i += 2) {
            final String key = StringUtils.lowerCase(StringUtils.trim(tokens[i - 1]), Locale.ROOT);
            final String value = StringUtils.trim(tokens[i]);
            if (StringUtils.isNotBlank(key)) {
                protocolAttributes.put(key, value);
            }
        }
        return protocolAttributes;
    }

    /**
     * Marks the handshake response as {@code 401 Unauthorized} so the client gets a meaningful status
     * instead of the container default, and signals that the handshake must be aborted.
     *
     * @param response the handshake response to update
     * @return always {@code false}
     */
    private static boolean reject(final ServerHttpResponse response) {
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        return false;
    }

}
