/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.analytics.expansion;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.apache.cassandra.analytics.DataGenerationUtils;
import org.apache.cassandra.analytics.ResiliencyTestBase;
import org.apache.cassandra.analytics.TestConsistencyLevel;
import org.apache.cassandra.analytics.TestUninterruptibles;
import org.apache.cassandra.distributed.api.Feature;
import org.apache.cassandra.distributed.api.ICluster;
import org.apache.cassandra.distributed.api.IInstance;
import org.apache.cassandra.sidecar.testing.QualifiedName;
import org.apache.cassandra.spark.bulkwriter.WriterOptions;
import org.apache.cassandra.testing.ClusterBuilderConfiguration;
import org.apache.cassandra.testing.IClusterExtension;
import org.apache.cassandra.testing.utils.ClusterUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.provider.Arguments;

/**
 * Shared scaffolding for tests that bulk-write through Spark while additional
 * instances are being added to the cluster.
 *
 * <p>After the cluster is provisioned, new instances are added and started on
 * background threads, and the test waits until each of them reports the
 * {@code "Joining"} ring state. Subclasses supply the latch that signals the
 * start of the transitional state via {@link #transitioningStateStart()}.
 */
abstract class JoiningTestBase
extends ResiliencyTestBase {
    /** Number of rows generated for, and expected back from, every test table. */
    private static final int JOINING_ROW_COUNT = 1000;

    /** Generated data frame written by each scenario; populated in {@link #beforeTestStart()}. */
    Dataset<Row> df;
    /** Per-instance expected data; populated in {@link #beforeTestStart()}. */
    Map<IInstance, Set<String>> expectedInstanceData;
    /** Instances added on top of the provisioned cluster; populated in {@link #afterClusterProvisioned()}. */
    List<IInstance> newInstances;

    JoiningTestBase() {
    }

    /**
     * Bulk-writes the generated data frame into the table derived from the given
     * consistency levels, then validates the table contents.
     *
     * @param cl read/write consistency level pair; the write level is passed to the
     *           bulk writer and the read level is used for validation
     */
    protected void runJoiningTestScenario(TestConsistencyLevel cl) {
        QualifiedName targetTable = JoiningTestBase.uniqueTestTableFullName("spark_test", cl.readCL, cl.writeCL);
        this.bulkWriterDataFrameWriter(this.df, targetTable)
            .option(WriterOptions.BULK_WRITER_CL.name(), cl.writeCL.name())
            .save();
        this.assertTableContents(targetTable, cl);
    }

    /**
     * Runs the parent setup, then generates the data frame to write and the data
     * expected on each instance.
     */
    @Override
    protected void beforeTestStart() {
        super.beforeTestStart();
        SparkSession session = this.getOrCreateSparkSession();
        this.df = DataGenerationUtils.generateCourseData(session, JOINING_ROW_COUNT);
        this.expectedInstanceData = this.generateExpectedInstanceData(
            (ICluster<? extends IInstance>)this.cluster, this.newInstances, JOINING_ROW_COUNT);
    }

    /**
     * Adds the configured number of new instances per data center, waits (up to two
     * minutes) on the latch returned by {@link #transitioningStateStart()}, and then
     * waits for every new instance to report the {@code "Joining"} ring state.
     */
    protected void afterClusterProvisioned() {
        ClusterBuilderConfiguration clusterConfig = this.testClusterConfiguration();
        this.newInstances = JoiningTestBase.addNewInstances(
            (IClusterExtension<? extends IInstance>)this.cluster, clusterConfig.newNodesPerDc, clusterConfig.dcCount);
        TestUninterruptibles.awaitUninterruptiblyOrThrow(this.transitioningStateStart(), 2L, TimeUnit.MINUTES);
        for (IInstance joining : this.newInstances) {
            // Each new instance is queried for its own entry in the ring.
            this.cluster.awaitRingState(joining, joining, "Joining");
        }
    }

    /**
     * Releases the given latch completely, re-validates the data for every supplied
     * test input and, when a failure is expected, asserts that no new instance found
     * in the ring is in the {@code "Normal"} state.
     *
     * @param transitionalStateEnd latch counted down to zero before validation
     * @param testInputs           arguments whose first element is a {@link TestConsistencyLevel}
     * @param failureExpected      whether to assert that the new instances are not {@code "Normal"}
     */
    protected void completeTransitionsAndValidateWrites(CountDownLatch transitionalStateEnd, Stream<Arguments> testInputs, boolean failureExpected) {
        // Snapshot the outstanding count once and count down exactly that many times.
        long pending = transitionalStateEnd.getCount();
        for (long released = 0L; released < pending; released++) {
            transitionalStateEnd.countDown();
        }

        testInputs.forEach(input -> {
            TestConsistencyLevel cl = (TestConsistencyLevel)input.get()[0];
            this.assertTableContents(JoiningTestBase.uniqueTestTableFullName("spark_test", cl.readCL, cl.writeCL), cl);
        });

        if (!failureExpected) {
            return;
        }
        for (IInstance joiningNode : this.newInstances) {
            // Instances that do not appear in the ring at all are not asserted on.
            this.getMatchingInstanceFromRing(this.cluster.get(1), joiningNode)
                .ifPresent(details -> Assertions.assertThat(details.getState()).isNotEqualTo("Normal"));
        }
    }

    /**
     * Supplies the latch awaited in {@link #afterClusterProvisioned()} after the new
     * instances have been started.
     *
     * @return latch that this base class waits on for up to two minutes
     */
    protected abstract CountDownLatch transitioningStateStart();

    /**
     * Validates the row count at the read consistency level and the per-instance
     * data for the given table.
     */
    private void assertTableContents(QualifiedName table, TestConsistencyLevel cl) {
        this.validateData(table, cl.readCL, JOINING_ROW_COUNT);
        this.validateNodeSpecificData(table, this.expectedInstanceData);
    }

    /**
     * Adds {@code newNodesPerDc * numDcs} instances to the cluster and starts each one
     * on its own thread without waiting for startup to finish.
     *
     * @param cluster       cluster to extend
     * @param newNodesPerDc number of passes over the data centers
     * @param numDcs        number of instances added per pass
     * @return the added instances, in creation order
     */
    private static List<IInstance> addNewInstances(IClusterExtension<? extends IInstance> cluster, int newNodesPerDc, int numDcs) {
        List<IInstance> added = new ArrayList<>();
        for (int pass = 0; pass < newNodesPerDc; pass++) {
            for (int dc = 1; dc <= numDcs; dc++) {
                // The existing node at index dc supplies the data center and rack for the new instance.
                IInstance template = cluster.get(dc);
                String datacenter = template.config().localDatacenter();
                String rack = template.config().localRack();
                IInstance created = cluster.addInstance(datacenter, rack, config -> {
                    config.set("auto_bootstrap", Boolean.TRUE);
                    config.with(new Feature[]{Feature.GOSSIP, Feature.JMX, Feature.NATIVE_PROTOCOL});
                });
                // Startup runs on a separate thread so this method returns without waiting for it.
                Thread starter = new Thread(() -> created.startup(cluster.delegate()));
                starter.start();
                added.add(created);
            }
        }
        return added;
    }

    /**
     * Looks up the ring entry, as reported by {@code seed}, whose address equals the
     * broadcast host address of {@code instance}.
     *
     * @param seed     instance whose ring view is queried
     * @param instance instance to look for
     * @return the first matching ring entry, or empty when none matches
     */
    Optional<ClusterUtils.RingInstanceDetails> getMatchingInstanceFromRing(IInstance seed, IInstance instance) {
        String wantedAddress = instance.broadcastAddress().getAddress().getHostAddress();
        for (ClusterUtils.RingInstanceDetails entry : ClusterUtils.ring(seed)) {
            if (entry.getAddress().equals(wantedAddress)) {
                return Optional.of(entry);
            }
        }
        return Optional.empty();
    }
}

