/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.rerank.context;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ObjectPath;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;

/**
 * {@link ContextSourceFetcher} that extracts the query text for reranking from the
 * rerank search-ext parameters of a search request.
 *
 * <p>The ext parameters must contain a {@value #NAME} map holding exactly one of:
 * <ul>
 *   <li>{@value #QUERY_TEXT_FIELD} - the literal query text, or</li>
 *   <li>{@value #QUERY_TEXT_PATH_FIELD} - a dot-separated path that is evaluated against the
 *       map form of the request source and must resolve to a string.</li>
 * </ul>
 * The resolved text is published to the listener under the {@value #QUERY_TEXT_FIELD} key.
 */
public class QueryContextSourceFetcher implements ContextSourceFetcher {
    @Generated
    private static final Logger log = LogManager.getLogger(QueryContextSourceFetcher.class);
    /** Name of this fetcher and key of the context map inside the rerank ext parameters. */
    public static final String NAME = "query_context";
    /** Key holding the literal query text; also the key of the emitted context entry. */
    public static final String QUERY_TEXT_FIELD = "query_text";
    /** Key holding a path into the search request that points at the query text. */
    public static final String QUERY_TEXT_PATH_FIELD = "query_text_path";
    /** Maximum number of characters accepted for a {@value #QUERY_TEXT_PATH_FIELD} value. */
    public static final Integer MAX_QUERY_PATH_STRLEN = 1000;
    private final ClusterService clusterService;

    /**
     * Resolves the query text for the given request and hands it to the listener.
     *
     * <p>Every exception raised while resolving the context (including validation failures)
     * is routed to {@link ActionListener#onFailure(Exception)} rather than thrown.
     *
     * @param searchRequest  request whose source carries the rerank ext parameters
     * @param searchResponse response being reranked; not read by this fetcher
     * @param listener       receives a map with a single {@value #QUERY_TEXT_FIELD} entry on
     *                       success, or the failure cause
     */
    @Override
    public void fetchContext(SearchRequest searchRequest, SearchResponse searchResponse, ActionListener<Map<String, Object>> listener) {
        try {
            Map<String, Object> params = RerankSearchExtBuilder.fromExtBuilderList(searchRequest.source().ext()).getParams();
            if (!params.containsKey(NAME)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "must specify %s", NAME));
            }
            // NOTE(review): remove() takes the entry out of the map returned by getParams(); if that
            // is the ext builder's live state, the request is mutated (and the entry is absent from
            // the map built by requestToMap). Kept as-is to preserve behavior - confirm intent.
            Object ctxObj = params.remove(NAME);
            if (!(ctxObj instanceof Map)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a map", NAME));
            }
            Map<String, Object> rerankContext = new HashMap<>();
            rerankContext.put(QUERY_TEXT_FIELD, resolveQueryText((Map<?, ?>) ctxObj, searchRequest));
            listener.onResponse(rerankContext);
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * @return the fetcher name, {@value #NAME}
     */
    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Determines the query text from the {@value #NAME} map, either directly from
     * {@value #QUERY_TEXT_FIELD} or by following {@value #QUERY_TEXT_PATH_FIELD} into the request.
     *
     * @param ctxMap        the {@value #NAME} map taken from the rerank ext parameters
     * @param searchRequest request the path is evaluated against
     * @return the resolved, non-null query text
     * @throws IllegalArgumentException if both or neither key is present, if a value is not a
     *                                  string, or if the path fails validation or does not
     *                                  resolve to a string
     * @throws IOException              if the request source cannot be converted to a map
     */
    private String resolveQueryText(Map<?, ?> ctxMap, SearchRequest searchRequest) throws IOException {
        boolean hasText = ctxMap.containsKey(QUERY_TEXT_FIELD);
        boolean hasPath = ctxMap.containsKey(QUERY_TEXT_PATH_FIELD);
        if (hasText && hasPath) {
            throw new IllegalArgumentException(
                String.format(Locale.ROOT, "Cannot specify both \"%s\" and \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD)
            );
        }
        if (hasText) {
            Object text = ctxMap.get(QUERY_TEXT_FIELD);
            // Validate instead of blindly casting so callers get a descriptive error rather than
            // a ClassCastException (or a null query text) for malformed input.
            if (!(text instanceof String)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a string", QUERY_TEXT_FIELD));
            }
            return (String) text;
        }
        if (hasPath) {
            Object pathObj = ctxMap.get(QUERY_TEXT_PATH_FIELD);
            if (!(pathObj instanceof String)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "%s must be a string", QUERY_TEXT_PATH_FIELD));
            }
            String path = (String) pathObj;
            validatePath(path);
            Object queryText = ObjectPath.eval(path, requestToMap(searchRequest));
            if (!(queryText instanceof String)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "%s must point to a string field", QUERY_TEXT_PATH_FIELD)
                );
            }
            return (String) queryText;
        }
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "Must specify either \"%s\" or \"%s\"", QUERY_TEXT_FIELD, QUERY_TEXT_PATH_FIELD)
        );
    }

    /**
     * Converts the request source into a generic map by serializing it to CBOR and parsing the
     * bytes back.
     *
     * @param request request whose source is converted
     * @return map representation of the request source
     * @throws IOException if serialization or parsing fails
     */
    private static Map<String, Object> requestToMap(SearchRequest request) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // Closing the builder flushes the generated content into baos; try-with-resources
        // guarantees that also happens when toXContent throws.
        try (XContentBuilder builder = XContentType.CBOR.contentBuilder(baos)) {
            request.source().toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        try (
            InputStream in = new ByteArrayInputStream(baos.toByteArray());
            XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, in)
        ) {
            return parser.map();
        }
    }

    /**
     * Rejects paths that are too long or too deeply nested. A {@code null} or empty path is
     * not checked here and is left to the caller.
     *
     * @param path dot-separated path supplied via {@value #QUERY_TEXT_PATH_FIELD}
     * @throws IllegalArgumentException if the path has more than {@link #MAX_QUERY_PATH_STRLEN}
     *                                  characters, or more dot-separated segments than the
     *                                  value of {@link MapperService#INDEX_MAPPING_DEPTH_LIMIT_SETTING}
     *                                  read from the cluster settings
     */
    private void validatePath(String path) throws IllegalArgumentException {
        if (path == null || path.isEmpty()) {
            return;
        }
        if (path.length() > MAX_QUERY_PATH_STRLEN) {
            log.error("invalid {} due to too many characters: {}", QUERY_TEXT_PATH_FIELD, path);
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "%s exceeded the maximum path length of %d characters",
                    QUERY_TEXT_PATH_FIELD,
                    MAX_QUERY_PATH_STRLEN
                )
            );
        }
        // Read the setting once so the comparison and the error message use the same value.
        long depthLimit = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(this.clusterService.getSettings());
        if (path.split("\\.").length > depthLimit) {
            log.error("invalid {} due to too many nested fields: {}", QUERY_TEXT_PATH_FIELD, path);
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "%s exceeded the maximum path length of %d nested fields",
                    QUERY_TEXT_PATH_FIELD,
                    depthLimit
                )
            );
        }
    }

    /**
     * @param clusterService source of the cluster settings consulted during path validation
     */
    @Generated
    public QueryContextSourceFetcher(ClusterService clusterService) {
        this.clusterService = clusterService;
    }
}

