diff --git a/CHANGES.txt b/CHANGES.txt index 512a574614de..f53a19663cff 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,6 @@ +trunk?? (The current trunk is on 5.1) +* Support ZSTD dictionary compression (CASSANDRA-17021) + 5.1 * Update snakeyaml to 2.4 (CASSANDRA-20928) * Update Netty to 4.1.125.Final (CASSANDRA-20925) diff --git a/conf/cassandra.yaml b/conf/cassandra.yaml index cfa1b7ff9ed8..45cac6f3ffe7 100644 --- a/conf/cassandra.yaml +++ b/conf/cassandra.yaml @@ -617,6 +617,54 @@ counter_cache_save_period: 7200s # Disabled by default, meaning all keys are going to be saved # counter_cache_keys_to_save: 100 +# Dictionary compression settings for ZSTD dictionary-based compression +# These settings control the automatic training and caching of compression dictionaries +# for tables that use ZSTD dictionary compression. + +# How often to refresh compression dictionaries across the cluster. +# During refresh, nodes will check for newer dictionary versions and update their caches. +# Min unit: s +compression_dictionary_refresh_interval: 3600s + +# Initial delay before starting the first dictionary refresh cycle after node startup. +# This prevents all nodes from refreshing simultaneously when the cluster starts. +# Min unit: s +compression_dictionary_refresh_initial_delay: 10s + +# Maximum number of compression dictionaries to cache per table. +# Each table using dictionary compression can have multiple dictionaries cached +# (current version plus recently used versions for reading older SSTables). +compression_dictionary_cache_size: 10 + +# How long to keep compression dictionaries in the cache before they expire. +# Expired dictionaries will be removed from memory but can be reloaded if needed. +# Min unit: s +compression_dictionary_cache_expire: 3600s + +# Dictionary training configuration (advanced settings) +# These settings control how compression dictionaries are trained from sample data. + +# Maximum size of a trained compression dictionary in bytes. +# Larger dictionaries may provide better compression but use more memory. +# Min unit: B +compression_dictionary_training_max_dictionary_size: 65536 + +# Maximum total size of sample data to collect for dictionary training. +# More sample data generally produces better dictionaries but takes longer to train. +# The recommended sample size is 100x the dictionary size. +# Min unit: B +compression_dictionary_training_max_total_sample_size: 10485760 + +# Enable automatic dictionary training based on sampling of write operations. +# When enabled, the system will automatically collect samples and train new dictionaries. +# Manual training via nodetool is always available regardless of this setting. +compression_dictionary_training_auto_train_enabled: false + +# Sampling rate for automatic dictionary training (1-10000). +# Value of 100 means 1% of writes are sampled. Lower values reduce overhead but may +# result in less representative sample data for dictionary training. +compression_dictionary_training_sampling_rate: 100 + # saved caches # If not set, the default directory is $CASSANDRA_HOME/data/saved_caches. 
# saved_caches_directory: /var/lib/cassandra/saved_caches diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java index e931b5a9d9dc..bf032826a688 100644 --- a/src/java/org/apache/cassandra/config/Config.java +++ b/src/java/org/apache/cassandra/config/Config.java @@ -514,6 +514,17 @@ public static class SSTableConfig public volatile DurationSpec.IntSecondsBound counter_cache_save_period = new DurationSpec.IntSecondsBound("7200s"); public volatile int counter_cache_keys_to_save = Integer.MAX_VALUE; + public volatile DurationSpec.IntSecondsBound compression_dictionary_refresh_interval = new DurationSpec.IntSecondsBound("3600s"); // 1 hour - TODO: re-assess whether daily (86400s) is more appropriate + public volatile DurationSpec.IntSecondsBound compression_dictionary_refresh_initial_delay = new DurationSpec.IntSecondsBound("10s"); // 10 seconds default + public volatile int compression_dictionary_cache_size = 10; // max dictionaries per table + public volatile DurationSpec.IntSecondsBound compression_dictionary_cache_expire = new DurationSpec.IntSecondsBound("3600s"); // 1 hour default + + // Dictionary training settings + public volatile int compression_dictionary_training_max_dictionary_size = 65536; // 64KB + public volatile int compression_dictionary_training_max_total_sample_size = 10485760; // 10MB total + public volatile boolean compression_dictionary_training_auto_train_enabled = false; + public volatile int compression_dictionary_training_sampling_rate = 100; // samples 1%; using int since random.nextInt is generally faster than random.nextDouble + public DataStorageSpec.LongMebibytesBound paxos_cache_size = null; public DataStorageSpec.LongMebibytesBound consensus_migration_cache_size = null; diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java index ec76193e1046..052d6d968341 100644 --- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java +++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java @@ -4361,6 +4361,47 @@ public static void setCounterCacheKeysToSave(int counterCacheKeysToSave) conf.counter_cache_keys_to_save = counterCacheKeysToSave; } + public static int getCompressionDictionaryRefreshIntervalSeconds() + { + return conf.compression_dictionary_refresh_interval.toSeconds(); + } + + public static int getCompressionDictionaryRefreshInitialDelaySeconds() + { + return conf.compression_dictionary_refresh_initial_delay.toSeconds(); + } + + public static int getCompressionDictionaryCacheSize() + { + return conf.compression_dictionary_cache_size; + } + + public static int getCompressionDictionaryCacheExpireSeconds() + { + return conf.compression_dictionary_cache_expire.toSeconds(); + } + + public static int getCompressionDictionaryTrainingMaxDictionarySize() + { + return conf.compression_dictionary_training_max_dictionary_size; + } + + public static int getCompressionDictionaryTrainingMaxTotalSampleSize() + { + return conf.compression_dictionary_training_max_total_sample_size; + } + + public static boolean getCompressionDictionaryTrainingAutoTrainEnabled() + { + return conf.compression_dictionary_training_auto_train_enabled; + } + + + public static int getCompressionDictionaryTrainingSamplingRate() + { + return conf.compression_dictionary_training_sampling_rate; + } + public static int getStreamingKeepAlivePeriod() { return conf.streaming_keep_alive_period.toSeconds(); diff --git 
a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java index 0d6cd2eb9be2..839a21d55d67 100644 --- a/src/java/org/apache/cassandra/db/ColumnFamilyStore.java +++ b/src/java/org/apache/cassandra/db/ColumnFamilyStore.java @@ -83,6 +83,7 @@ import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.db.compaction.CompactionStrategyManager; import org.apache.cassandra.db.compaction.OperationType; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.filter.ClusteringIndexFilter; import org.apache.cassandra.db.filter.DataLimits; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; @@ -320,6 +321,7 @@ public enum FlushReason public final TopPartitionTracker topPartitions; private final SSTableImporter sstableImporter; + private final CompressionDictionaryManager compressionDictionaryManager; private volatile boolean compactionSpaceCheck = true; @@ -390,6 +392,7 @@ public void reload(TableMetadata tableMetadata) cfs.crcCheckChance = new DefaultValue<>(tableMetadata.params.crcCheckChance); compactionStrategyManager.maybeReloadParamsFromSchema(tableMetadata.params.compaction); + compressionDictionaryManager.maybeReloadFromSchema(tableMetadata.params.compression); indexManager.reload(tableMetadata); @@ -576,6 +579,7 @@ public ColumnFamilyStore(Keyspace keyspace, streamManager = new CassandraStreamManager(this); repairManager = new CassandraTableRepairManager(this); sstableImporter = new SSTableImporter(this); + compressionDictionaryManager = new CompressionDictionaryManager(this, registerBookeeping); if (DatabaseDescriptor.isClientOrToolInitialized() || SchemaConstants.isSystemKeyspace(getKeyspaceName())) topPartitions = null; @@ -733,6 +737,8 @@ public void invalidate(boolean expectMBean, boolean dropData) invalidateCaches(); if (topPartitions != null) topPartitions.close(); + + compressionDictionaryManager.close(); } /** @@ -3420,6 +3426,12 @@ public TableMetrics getMetrics() return metric; } + @Override + public CompressionDictionaryManager compressionDictionaryManager() + { + return compressionDictionaryManager; + } + public TableId getTableId() { return metadata().id; diff --git a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java index 3ab00acfb954..4313eb82a1d4 100644 --- a/src/java/org/apache/cassandra/db/compaction/CompactionManager.java +++ b/src/java/org/apache/cassandra/db/compaction/CompactionManager.java @@ -1796,6 +1796,7 @@ public static SSTableWriter createWriter(ColumnFamilyStore cfs, .setSerializationHeader(sstable.header) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } @@ -1836,6 +1837,7 @@ public static SSTableWriter createWriterForAntiCompaction(ColumnFamilyStore cfs, .setSerializationHeader(SerializationHeader.make(cfs.metadata(), sstables)) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/Upgrader.java b/src/java/org/apache/cassandra/db/compaction/Upgrader.java index 9e4c4dd7502b..22913a84f612 100644 --- a/src/java/org/apache/cassandra/db/compaction/Upgrader.java +++ 
b/src/java/org/apache/cassandra/db/compaction/Upgrader.java @@ -85,6 +85,7 @@ private SSTableWriter createCompactionWriter(StatsMetadata metadata) .setSerializationHeader(SerializationHeader.make(cfs.metadata(), Sets.newHashSet(sstable))) .addDefaultComponents(cfs.indexManager.listIndexGroups()) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(transaction, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java b/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java index ab5465df253c..79efadbfa37f 100644 --- a/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java +++ b/src/java/org/apache/cassandra/db/compaction/unified/ShardedMultiWriter.java @@ -118,6 +118,7 @@ private SSTableWriter createWriter(Descriptor descriptor) .setSerializationHeader(header) .addDefaultComponents(indexGroups) .setSecondaryIndexGroups(indexGroups) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()) .build(txn, cfs); } diff --git a/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java b/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java index ea21f7be57e0..fd1966ad35d3 100644 --- a/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java +++ b/src/java/org/apache/cassandra/db/compaction/writers/CompactionAwareWriter.java @@ -329,6 +329,7 @@ protected long getExpectedWriteSize() .setRepairedAt(minRepairedAt) .setPendingRepair(pendingRepair) .setSecondaryIndexGroups(cfs.indexManager.listIndexGroups()) - .addDefaultComponents(cfs.indexManager.listIndexGroups()); + .addDefaultComponents(cfs.indexManager.listIndexGroups()) + .setCompressionDictionaryManager(cfs.compressionDictionaryManager()); } } diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java new file mode 100644 index 000000000000..65c5430310ee --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionary.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
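The compression_dictionary_training_sampling_rate setting introduced above (cassandra.yaml and Config.java) is documented as "n means roughly 1 in n writes is sampled", and the Config comment notes an int rate is used because random.nextInt is cheaper than random.nextDouble. The trainer's shouldSample() body is not part of these hunks, so the snippet below is only an illustrative sketch of that 1-in-N semantics under those assumptions, not the patch's actual implementation:

```java
import java.util.concurrent.ThreadLocalRandom;

// Illustrative sketch of the sampling_rate semantics documented above
// (1 = every write, n = roughly 1 in n writes). The real shouldSample()
// is not shown in this hunk; the nextInt-based check is an assumption.
public final class SamplingRateDemo
{
    static boolean shouldSample(int samplingRate)
    {
        // nextInt(n) is uniform over [0, n); hitting 0 happens ~1/n of the time
        return ThreadLocalRandom.current().nextInt(samplingRate) == 0;
    }

    public static void main(String[] args)
    {
        int rate = 100; // the default: roughly 1% of writes sampled
        int sampled = 0, total = 1_000_000;
        for (int i = 0; i < total; i++)
            if (shouldSample(rate))
                sampled++;
        System.out.printf("sampled %d of %d writes (~%.2f%%)%n", sampled, total, 100.0 * sampled / total);
    }
}
```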
+ */ + +package org.apache.cassandra.db.compression; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.EOFException; +import java.io.IOException; +import java.util.Objects; +import java.util.Set; +import javax.annotation.Nullable; + +import com.google.common.collect.ImmutableSet; +import com.google.common.hash.Hasher; +import com.google.common.hash.Hashing; + +import org.apache.cassandra.cql3.UntypedResultSet; +import org.apache.cassandra.io.compress.ICompressor; +import org.apache.cassandra.io.compress.IDictionaryCompressor; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; + +public interface CompressionDictionary extends AutoCloseable +{ + /** + * Get the dictionary id + * + * @return dictionary id + */ + DictId identifier(); + + /** + * Get the raw bytes of the compression dictionary + * + * @return raw compression dictionary + */ + byte[] rawDictionary(); + + /** + * Get the kind of the compression algorithm + * + * @return compression algorithm kind + */ + default Kind kind() + { + return identifier().kind; + } + + default IDictionaryCompressor getCompressor() + { + return kind().getCompressor(this); + } + + /** + * Write compression dictionary to file + * + * @param out file output stream + * @throws IOException on any I/O exception when writing to the file + */ + default void serialize(DataOutput out) throws IOException + { + DictId dictId = identifier(); + int ordinal = dictId.kind.ordinal(); + out.writeByte(ordinal); + out.writeLong(dictId.id); + byte[] dict = rawDictionary(); + out.writeInt(dict.length); + out.write(dict); + int checksum = calculateChecksum((byte) ordinal, dictId.id, dict); + out.writeInt(checksum); + } + + /** + * A factory method to create concrete CompressionDictionary from the file content + * + * @param input file input stream + * @param manager compression dictionary manager that caches the dictionaries + * @return compression dictionary; otherwise, null if there is no dictionary + * @throws IOException on any I/O exception when reading from the file + */ + @Nullable + static CompressionDictionary deserialize(DataInput input, @Nullable CompressionDictionaryManager manager) throws IOException + { + int kindOrdinal; + try + { + kindOrdinal = input.readByte(); + } + catch (EOFException eof) + { + // no dictionary + return null; + } + + if (kindOrdinal < 0 || kindOrdinal >= Kind.values().length) + { + throw new IOException("Invalid compression dictionary kind: " + kindOrdinal); + } + Kind kind = Kind.values()[kindOrdinal]; + long id = input.readLong(); + DictId dictId = new DictId(kind, id); + + if (manager != null) + { + CompressionDictionary dictionary = manager.get(dictId); + if (dictionary != null) + { + return dictionary; + } + } + + int length = input.readInt(); + byte[] dict = new byte[length]; + input.readFully(dict); + int checksum = input.readInt(); + int calculatedChecksum = calculateChecksum((byte) kindOrdinal, id, dict); + if (checksum != calculatedChecksum) + throw new IOException("Compression dictionary checksum does not match"); + + CompressionDictionary dictionary = kind.getDictionary(dictId, dict); + + // update the dictionary manager if it exists + if (manager != null) + manager.add(dictionary); + + return dictionary; + } + + static CompressionDictionary createFromRow(UntypedResultSet.Row row) + { + String kindStr = row.getString("kind"); + long dictId = row.getLong("dict_id"); + + try + { + Kind kind = CompressionDictionary.Kind.valueOf(kindStr); + return kind.getDictionary(new DictId(kind, dictId), 
row.getByteArray("dict")); + } + catch (IllegalArgumentException ex) + { + throw new IllegalStateException(kindStr + " compression dictionary is not created for dict id " + dictId); + } + } + + @SuppressWarnings("UnstableApiUsage") + static int calculateChecksum(byte kindOrdinal, long dictId, byte[] dict) + { + Hasher hasher = Hashing.crc32c().newHasher(); + hasher.putByte(kindOrdinal); + hasher.putLong(dictId); + hasher.putBytes(dict); + return hasher.hash().asInt(); + } + + enum Kind + { + // Order matters: the enum ordinal is serialized + ZSTD + { + public CompressionDictionary getDictionary(DictId dictId, byte[] dict) + { + return new ZstdCompressionDictionary(dictId, dict); + } + + @Override + public IDictionaryCompressor getCompressor(CompressionDictionary dictionary) + { + assert dictionary instanceof ZstdCompressionDictionary; + return ZstdDictionaryCompressor.create((ZstdCompressionDictionary) dictionary); + } + + @Override + public ICompressionDictionaryTrainer getTrainer(String keyspaceName, String tableName, CompressionDictionaryTrainingConfig config, ICompressor compressor) + { + assert compressor instanceof ZstdDictionaryCompressor; + return new ZstdDictionaryTrainer(keyspaceName, tableName, config, ((ZstdDictionaryCompressor) compressor).compressionLevel()); + } + }; + + public static final Set ACCEPTABLE_DICTIONARY_KINDS = ImmutableSet.of(Kind.ZSTD); + + public abstract CompressionDictionary getDictionary(CompressionDictionary.DictId dictId, byte[] dict); + + public abstract IDictionaryCompressor getCompressor(CompressionDictionary dictionary); + + public abstract ICompressionDictionaryTrainer getTrainer(String keyspaceName, String tableName, CompressionDictionaryTrainingConfig config, ICompressor compressor); + } + + final class DictId + { + public final Kind kind; + public final long id; // A value of negative or 0 means no dictionary + + /** + * Creates a monotonically increasing dictionary ID by combining timestamp and dictionary ID. + *
<p>
+ * The resulting dictionary ID has the following structure:
+ * - Upper 32 bits: timestamp in minutes (signed int)
+ * - Lower 32 bits: Zstd dictionary ID (unsigned int, passed as long due to Java limitations)
+ * <p>
+ * This ensures dictionary IDs are monotonically increasing over time, which helps to identify
+ * the latest dictionary.
+ * <p>
+ * The implementation assumes that dictionary training frequency is significantly larger than + * every minute, which a healthy system should do. In the scenario when multiple dictionaries + * are trained in the same minute (only possible using manual training), there should not be + * correctness concerns since the dictionary is attached to the SSTables, but leads to performance + * hit from having too many dictionary. Therefore, such scenario should be avoided at the best. + * + * @param currentTimeMillis the current time in milliseconds + * @param dictId dictionary ID (unsigned 32-bit value represented as long) + * @return combined dictionary ID that is monotonically increasing over time + */ + static long makeDictId(long currentTimeMillis, long dictId) + { + // timestamp in minutes since Unix epoch. Good until year 6053 + long timestampMinutes = currentTimeMillis / 1000 / 60; + // Convert timestamp to long and shift to upper 32 bits + long combined = timestampMinutes << 32; + + // Add the unsigned int (already as long) to lower 32 bits + combined |= (dictId & 0xFFFFFFFFL); + + return combined; + } + + public DictId(Kind kind, long id) + { + this.kind = kind; + this.id = id; + } + + @Override + public boolean equals(Object o) + { + if (!(o instanceof DictId)) return false; + DictId dictId = (DictId) o; + return id == dictId.id && kind == dictId.kind; + } + + @Override + public int hashCode() + { + return Objects.hash(kind, id); + } + + @Override + public String toString() + { + return "DictId{" + + "kind=" + kind + + ", id=" + id + + '}'; + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java new file mode 100644 index 000000000000..e3e636108ca5 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryCache.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import org.apache.cassandra.config.DatabaseDescriptor; + +/** + * Manages caching and current dictionary state for compression dictionaries. + *
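The DictId.makeDictId packing shown above (minutes since the epoch in the upper 32 bits, the unsigned Zstd dictionary id in the lower 32 bits) can be exercised with a small standalone snippet before continuing with the cache class; the timestamps and ids below are arbitrary example values:

```java
// Standalone sketch of the DictId.makeDictId packing above: epoch minutes in the
// upper 32 bits, the unsigned Zstd dictionary id in the lower 32 bits, so later
// training runs always produce numerically larger ids.
public final class DictIdPackingDemo
{
    static long makeDictId(long currentTimeMillis, long zstdDictId)
    {
        long timestampMinutes = currentTimeMillis / 1000 / 60;
        return (timestampMinutes << 32) | (zstdDictId & 0xFFFFFFFFL);
    }

    public static void main(String[] args)
    {
        long older = makeDictId(1_700_000_000_000L, 0xFFFF_FFFFL); // max unsigned 32-bit dict id
        long newer = makeDictId(1_700_000_060_000L, 1L);           // one minute later, small dict id
        System.out.println("older = " + Long.toHexString(older));
        System.out.println("newer = " + Long.toHexString(newer));
        System.out.println("newer > older? " + (newer > older));   // true: the timestamp bits dominate
        // unpack: upper 32 bits = minutes, lower 32 bits = raw zstd id
        System.out.println("minutes = " + (newer >>> 32) + ", zstd id = " + (newer & 0xFFFFFFFFL));
    }
}
```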
<p>
+ * This class handles: + * - Local caching of compression dictionaries with automatic cleanup + * - Managing the current active dictionary for write operations + * - Thread-safe access to cached dictionaries + */ +public class CompressionDictionaryCache implements ICompressionDictionaryCache +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryCache.class); + + private final Cache cache; + private final AtomicReference currentDictionary = new AtomicReference<>(); + + public CompressionDictionaryCache() + { + Duration expiryTime = Duration.ofSeconds(DatabaseDescriptor.getCompressionDictionaryCacheExpireSeconds()); + this.cache = Caffeine.newBuilder() + .maximumSize(DatabaseDescriptor.getCompressionDictionaryCacheSize()) + .expireAfterAccess(expiryTime) + .removalListener((CompressionDictionary.DictId dictId, + CompressionDictionary dictionary, + RemovalCause cause) -> { + // Close dictionary when evicted from cache to free native resources + // SelfRefCounted ensures dictionary won't be actually closed if still referenced by compressors + if (dictionary != null) + { + try + { + dictionary.close(); + } + catch (Exception e) + { + logger.warn("Failed to close compression dictionary {}", dictId, e); + } + } + }) + .build(); + } + + @Nullable + @Override + public CompressionDictionary getCurrent() + { + return currentDictionary.get(); + } + + @Nullable + @Override + public CompressionDictionary get(CompressionDictionary.DictId dictId) + { + return cache.getIfPresent(dictId); + } + + @Override + public void add(CompressionDictionary compressionDictionary) + { + cache.put(compressionDictionary.identifier(), compressionDictionary); + } + + @Override + public void setCurrentIfNewer(@Nullable CompressionDictionary dictionary) + { + if (dictionary == null) + return; + + add(dictionary); + // Only update the current dictionary if we don't have one or the new one has a higher ID (newer) + CompressionDictionary current = currentDictionary.get(); + while ((current == null || dictionary.identifier().id > current.identifier().id) + && !currentDictionary.compareAndSet(current, dictionary)) + { + current = currentDictionary.get(); + } + } + + @Override + public synchronized void close() + { + CompressionDictionary dictionary = currentDictionary.get(); + // Close current dictionary + if (dictionary != null) + { + try + { + dictionary.close(); + } + catch (Exception e) + { + logger.warn("Failed to close current compression dictionary", e); + } + } + currentDictionary.set(null); + + // Invalidate cache - this will trigger removalListener to close all cached dictionaries + cache.invalidateAll(); + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java new file mode 100644 index 000000000000..2bac03af5cfa --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryEventHandler.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
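CompressionDictionaryCache.setCurrentIfNewer above keeps the newest dictionary with a compare-and-set retry loop rather than a lock. A minimal JDK-only sketch of that pattern, with a plain Long standing in for the dictionary and its monotonically increasing id:

```java
import java.util.concurrent.atomic.AtomicReference;

// Minimal sketch of the "set current only if newer" CAS loop used by
// CompressionDictionaryCache.setCurrentIfNewer above; a Long stands in for the
// compression dictionary and its monotonically increasing id.
public final class NewestValueHolder
{
    private final AtomicReference<Long> current = new AtomicReference<>();

    public void setIfNewer(Long candidate)
    {
        if (candidate == null)
            return;
        Long cur = current.get();
        // Retry while the candidate is still newer than what we last observed and
        // another thread keeps winning the compareAndSet race.
        while ((cur == null || candidate > cur) && !current.compareAndSet(cur, candidate))
            cur = current.get();
    }

    public Long get()
    {
        return current.get();
    }

    public static void main(String[] args)
    {
        NewestValueHolder holder = new NewestValueHolder();
        holder.setIfNewer(5L);
        holder.setIfNewer(3L);  // older, ignored
        holder.setIfNewer(9L);  // newer, wins
        System.out.println(holder.get()); // 9
    }
}
```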
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.SystemDistributedKeyspace; +import org.apache.cassandra.tcm.ClusterMetadata; +import org.apache.cassandra.utils.FBUtilities; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; + +/** + * Handles compression dictionary events including training completion and cluster notifications. + *
<p>
+ * This class handles: + * - Broadcasting dictionary updates to cluster nodes + * - Retrieving new dictionaries when notified by other nodes + * - Managing dictionary cache updates + */ +public class CompressionDictionaryEventHandler implements ICompressionDictionaryEventHandler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryEventHandler.class); + + private final ColumnFamilyStore cfs; + private final String keyspaceName; + private final String tableName; + private final ICompressionDictionaryCache cache; + + public CompressionDictionaryEventHandler(ColumnFamilyStore cfs, ICompressionDictionaryCache cache) + { + this.cfs = cfs; + this.keyspaceName = cfs.keyspace.getName(); + this.tableName = cfs.getTableName(); + this.cache = cache; + } + + @Override + public void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId) + { + logger.info("Notifying cluster about dictionary update for {}.{} with {}", + keyspaceName, tableName, dictionaryId); + + CompressionDictionaryUpdateMessage message = new CompressionDictionaryUpdateMessage(cfs.metadata().id, dictionaryId); + Collection allNodes = ClusterMetadata.current().directory.allJoinedEndpoints(); + // Broadcast notification using the fire-and-forget fashion + for (InetAddressAndPort node : allNodes) + { + if (node.equals(FBUtilities.getBroadcastAddressAndPort())) // skip ourself + continue; + sendNotification(node, message); + } + } + + @Override + public void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId) + { + // Best effort to retrieve the dictionary; otherwise, the periodic task should retrieve the dictionary later + CompletableFuture.runAsync(() -> { + try + { + if (!cfs.metadata().params.compression.isDictionaryCompressionEnabled()) + { + return; + } + + CompressionDictionary dictionary = SystemDistributedKeyspace.retrieveCompressionDictionary(keyspaceName, tableName, dictionaryId); + cache.setCurrentIfNewer(dictionary); + } + catch (Exception e) + { + logger.warn("Failed to retrieve compression dictionary for {}.{}. {}", + keyspaceName, tableName, dictionaryId, e); + } + }, ScheduledExecutors.nonPeriodicTasks); + } + + // Best effort to notify the peer regarding the new dictionary being available to pull. + // If the request fails, each peer has periodic task scheduled to pull. + private void sendNotification(InetAddressAndPort target, CompressionDictionaryUpdateMessage message) + { + logger.debug("Sending dictionary update notification for {} to {}", message.dictionaryId, target); + + Message msg = Message.out(Verb.DICTIONARY_UPDATE_REQ, message); + MessagingService.instance() + .sendWithResponse(target, msg) + .addListener(future -> { + if (future.isSuccess()) + { + logger.debug("Successfully sent dictionary update notification to {}", target); + } + else + { + logger.warn("Failed to send dictionary update notification to {}", + target, future.cause()); + } + }); + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java new file mode 100644 index 000000000000..e111c20bb54f --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManager.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
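For reference, the on-disk layout implemented by CompressionDictionary.serialize()/deserialize() earlier in this patch is: kind ordinal (1 byte), dictionary id (8 bytes), length-prefixed dictionary bytes, then a CRC32C over the kind byte, id and raw dictionary. The standalone round trip below mirrors that layout with plain java.io streams and Guava's Hashing.crc32c(); it is an illustration of the format, not the patch's code path:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;

// Standalone walk-through of the dictionary serialization layout described above.
public final class DictionaryLayoutDemo
{
    static int checksum(byte kindOrdinal, long dictId, byte[] dict)
    {
        Hasher hasher = Hashing.crc32c().newHasher();
        hasher.putByte(kindOrdinal);
        hasher.putLong(dictId);
        hasher.putBytes(dict);
        return hasher.hash().asInt();
    }

    public static void main(String[] args) throws IOException
    {
        byte kind = 0; // ZSTD is ordinal 0 in the Kind enum above
        long dictId = 42L;
        byte[] dict = "sample dictionary bytes".getBytes(StandardCharsets.UTF_8);

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes))
        {
            out.writeByte(kind);
            out.writeLong(dictId);
            out.writeInt(dict.length);
            out.write(dict);
            out.writeInt(checksum(kind, dictId, dict));
        }

        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())))
        {
            byte k = in.readByte();
            long id = in.readLong();
            byte[] payload = new byte[in.readInt()];
            in.readFully(payload);
            int stored = in.readInt();
            System.out.println("checksum ok? " + (stored == checksum(k, id, payload)));
        }
    }
}
```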
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Map; +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.SystemDistributedKeyspace; +import org.apache.cassandra.utils.MBeanWrapper; +import org.apache.cassandra.utils.MBeanWrapper.OnException; + +public class CompressionDictionaryManager implements CompressionDictionaryManagerMBean, + ICompressionDictionaryCache, + ICompressionDictionaryEventHandler, + AutoCloseable +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryManager.class); + + private final String keyspaceName; + private final String tableName; + private volatile boolean mbeanRegistered; + private volatile boolean isEnabled; + + // Components + private final ICompressionDictionaryEventHandler eventHandler; + private final ICompressionDictionaryCache cache; + private final ICompressionDictionaryScheduler scheduler; + private ICompressionDictionaryTrainer trainer = null; + + public CompressionDictionaryManager(ColumnFamilyStore columnFamilyStore, boolean registerBookkeeping) + { + this.keyspaceName = columnFamilyStore.keyspace.getName(); + this.tableName = columnFamilyStore.getTableName(); + + this.isEnabled = columnFamilyStore.metadata().params.compression.isDictionaryCompressionEnabled(); + this.cache = new CompressionDictionaryCache(); + this.eventHandler = new CompressionDictionaryEventHandler(columnFamilyStore, cache); + this.scheduler = new CompressionDictionaryScheduler(keyspaceName, tableName, cache, isEnabled); + if (isEnabled) + { + // Initialize components + this.trainer = ICompressionDictionaryTrainer.create(keyspaceName, tableName, + columnFamilyStore.metadata().params.compression, + createTrainingConfig()); + trainer.setDictionaryTrainedListener(this::handleNewDictionary); + + scheduler.scheduleRefreshTask(); + + trainer.start(false); + } + + if (registerBookkeeping) + { + MBeanWrapper.instance.registerMBean(this, mbeanName(keyspaceName, tableName)); + } + mbeanRegistered = registerBookkeeping; + } + + static String mbeanName(String keyspaceName, String tableName) + { + return "org.apache.cassandra.db.compression:type=CompressionDictionaryManager" + + ",keyspace=" + keyspaceName + ",table=" + tableName; + } + + public boolean isEnabled() + { + return isEnabled; + } + + /** + * Reloads dictionary management configuration when compression parameters change. + * This method enables or disables dictionary compression based on the new parameters, + * and properly manages the lifecycle of training and refresh tasks. 
+ * + * @param newParams the new compression parameters to apply + */ + public synchronized void maybeReloadFromSchema(CompressionParams newParams) + { + this.isEnabled = newParams.isDictionaryCompressionEnabled(); + scheduler.setEnabled(isEnabled); + if (isEnabled) + { + // Check if we need a new trainer due to compression parameter changes + boolean needsNewTrainer = shouldCreateNewTrainer(newParams); + + if (needsNewTrainer) + { + // The manual training should be cancelled if a new trainer is needed + scheduler.cancelManualTraining(); + // Close existing trainer and create a new one + if (trainer != null) + { + try + { + trainer.close(); + } + catch (Exception e) + { + logger.warn("Failed to close existing trainer for {}.{}", keyspaceName, tableName, e); + } + } + + trainer = ICompressionDictionaryTrainer.create(keyspaceName, tableName, newParams, createTrainingConfig()); + trainer.setDictionaryTrainedListener(this::handleNewDictionary); + } + + scheduler.scheduleRefreshTask(); + + // Start trainer if it exists + if (trainer != null) + { + trainer.start(false); + } + return; + } + + // Clean up when dictionary compression is disabled + try + { + close(); + } + catch (Exception e) + { + logger.warn("Failed to close CompressionDictionaryManager on disabling " + + "dictionary-based compression for table {}.{}", keyspaceName, tableName); + } + } + + /** + * Adds a sample to the dictionary trainer for learning compression patterns. + * Samples are randomly selected to avoid bias and improve dictionary quality. + * + * @param sample the sample data to potentially add for training + */ + public void addSample(ByteBuffer sample) + { + ICompressionDictionaryTrainer dictionaryTrainer = trainer; + if (dictionaryTrainer != null && dictionaryTrainer.shouldSample()) + { + dictionaryTrainer.addSample(sample); + } + } + + @Nullable + @Override + public CompressionDictionary getCurrent() + { + return cache.getCurrent(); + } + + @Override + public CompressionDictionary get(CompressionDictionary.DictId dictId) + { + return cache.get(dictId); + } + + @Override + public void add(CompressionDictionary compressionDictionary) + { + cache.add(compressionDictionary); + } + + @Override + public void setCurrentIfNewer(@Nullable CompressionDictionary dictionary) + { + cache.setCurrentIfNewer(dictionary); + } + + @Override + public void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId) + { + eventHandler.onNewDictionaryTrained(dictionaryId); + } + + @Override + public void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId) + { + eventHandler.onNewDictionaryAvailable(dictionaryId); + } + + @Override + public synchronized void train(Map options) + { + // Validate table supports dictionary compression + if (!isEnabled) + { + throw new IllegalArgumentException("Table " + keyspaceName + '.' + tableName + " does not support dictionary compression"); + } + + if (trainer == null) + { + throw new IllegalStateException("Dictionary trainer is not available for table " + keyspaceName + '.' 
+ tableName); + } + + // Parse and validate training options + ManualTrainingOptions trainingOptions = ManualTrainingOptions.fromStringMap(options); + + trainer.start(true); + scheduler.scheduleManualTraining(trainingOptions, trainer); + } + + @Override + public String getTrainingStatus() + { + ICompressionDictionaryTrainer dictionaryTrainer = trainer; + if (dictionaryTrainer == null) + { + return TrainingStatus.NOT_STARTED.toString(); + } + return dictionaryTrainer.getTrainingStatus().toString(); + } + + @Override + public void updateSamplingRate(int samplingRate) + { + ICompressionDictionaryTrainer dictionaryTrainer = trainer; + if (dictionaryTrainer == null) + { + throw new IllegalArgumentException("Dictionary trainer is not available for table " + keyspaceName + '.' + tableName); + } + dictionaryTrainer.updateSamplingRate(samplingRate); + } + + /** + * Close all the resources. The method can be called multiple times. + */ + @Override + public synchronized void close() + { + unregisterMbean(); + if (trainer != null) + { + closeQuitely(trainer, "CompressionDictionaryTrainer"); + trainer = null; + } + closeQuitely(cache, "CompressionDictionaryCache"); + closeQuitely(scheduler, "CompressionDictionaryScheduler"); + } + + private void handleNewDictionary(CompressionDictionary dictionary) + { + // sequence meatters; persist the new dictionary before broadcasting to others. + storeDictionary(dictionary); + onNewDictionaryTrained(dictionary.identifier()); + } + + private CompressionDictionaryTrainingConfig createTrainingConfig() + { + return CompressionDictionaryTrainingConfig + .builder() + .maxDictionarySize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxDictionarySize()) + .maxTotalSampleSize(DatabaseDescriptor.getCompressionDictionaryTrainingMaxTotalSampleSize()) + .samplingRate(DatabaseDescriptor.getCompressionDictionaryTrainingSamplingRate()) + .build(); + } + + private void storeDictionary(CompressionDictionary dictionary) + { + if (!isEnabled) + { + return; + } + + SystemDistributedKeyspace.storeCompressionDictionary(keyspaceName, tableName, dictionary); + cache.setCurrentIfNewer(dictionary); + } + + /** + * Determines if a new trainer should be created based on compression parameter changes. + * A new trainer is needed when no existing trainer exists or when the existing trainer + * is not compatible with the new compression parameters. + * + * The method is (and should be) only invoked inside {@link #maybeReloadFromSchema(CompressionParams)}, + * which is guarded by synchronized. 
+ * + * @param newParams the new compression parameters + * @return true if a new trainer should be created + */ + private boolean shouldCreateNewTrainer(CompressionParams newParams) + { + if (trainer == null) + { + return true; + } + + return !trainer.isCompatibleWith(newParams); + } + + private void unregisterMbean() + { + if (mbeanRegistered) + { + MBeanWrapper.instance.unregisterMBean(mbeanName(keyspaceName, tableName), OnException.IGNORE); + mbeanRegistered = true; + } + } + + private void closeQuitely(AutoCloseable closeable, String objectName) + { + try + { + closeable.close(); + } + catch (Exception exception) + { + logger.warn("Failed closing {}", objectName, exception); + } + } + + @VisibleForTesting + boolean isReady() + { + return trainer != null && trainer.isReady(); + } + + @VisibleForTesting + ICompressionDictionaryTrainer trainer() + { + return trainer; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java new file mode 100644 index 000000000000..9df88bd9fd8e --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBean.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.Map; + +public interface CompressionDictionaryManagerMBean +{ + /** + * Starts sampling and training for this table. + * + * @param options options for the training process (currently unused, reserved for future extensions) + * @throws IllegalArgumentException if table doesn't support dictionary compression + * @throws IllegalStateException if training is already in progress for this table + */ + void train(Map options); + + /** + * Gets the current training status for this table. + * Enables async polling for status/completion. + * + * @return training status as string: "Not started", "In progress", "Completed", or "Failed" + */ + String getTrainingStatus(); + + /** + * Updates the sampling rate for the trainer. + * + * @param samplingRate the new sampling rate. 
For exmaple, 1 = sample every time (100%); + * 2 = expect sample 1/2 of data (50%), n = expect sample 1/n of data + * @throws IllegalArgumentException if sampling rate is invalid or trainer is not available + */ + void updateSamplingRate(int samplingRate); +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java new file mode 100644 index 000000000000..907ae8b3736e --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryScheduler.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.SystemDistributedKeyspace; + +/** + * Manages scheduled tasks for compression dictionary operations. + *
<p>
+ * This class handles: + * - Periodic refresh of dictionaries from system tables + * - Manual training task scheduling and monitoring + * - Cleanup of scheduled tasks + */ +public class CompressionDictionaryScheduler implements ICompressionDictionaryScheduler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryScheduler.class); + + private final String keyspaceName; + private final String tableName; + private final ICompressionDictionaryCache cache; + + private volatile ScheduledFuture scheduledRefreshTask; + private volatile ScheduledFuture scheduledManualTrainingTask; + private volatile boolean isEnabled; + + public CompressionDictionaryScheduler(String keyspaceName, + String tableName, + ICompressionDictionaryCache cache, + boolean isEnabled) + { + this.keyspaceName = keyspaceName; + this.tableName = tableName; + this.cache = cache; + this.isEnabled = isEnabled; + } + + /** + * Schedules the periodic dictionary refresh task if not already scheduled. + */ + public void scheduleRefreshTask() + { + if (scheduledRefreshTask != null) + return; + + this.scheduledRefreshTask = ScheduledExecutors.scheduledTasks.scheduleWithFixedDelay( + this::refreshDictionaryFromSystemTable, + DatabaseDescriptor.getCompressionDictionaryRefreshInitialDelaySeconds(), + DatabaseDescriptor.getCompressionDictionaryRefreshIntervalSeconds(), + TimeUnit.SECONDS + ); + } + + @Override + public void scheduleManualTraining(ManualTrainingOptions options, ICompressionDictionaryTrainer trainer) + { + if (scheduledManualTrainingTask != null) + { + throw new IllegalStateException("Training already in progress for table " + keyspaceName + '.' + tableName); + } + + int maxSamplingDurationSeconds = options.getMaxSamplingDurationSeconds(); + + logger.info("Starting manual dictionary training for {}.{} with max sampling duration: {} seconds", + keyspaceName, tableName, maxSamplingDurationSeconds); + + long deadlineMillis = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(maxSamplingDurationSeconds); + + ManualTrainingTask task = new ManualTrainingTask(deadlineMillis, trainer); + + // Check every second whether it gets enough samples and completes training + scheduledManualTrainingTask = ScheduledExecutors.scheduledTasks + .scheduleWithFixedDelay(task, 1, 1, TimeUnit.SECONDS); + } + + @Override + public void cancelManualTraining() + { + ScheduledFuture future = scheduledManualTrainingTask; + if (future != null) + { + future.cancel(false); + } + scheduledManualTrainingTask = null; + } + + /** + * Sets the enabled state of the scheduler. When disabled, refresh tasks will not execute. + * + * @param enabled whether the scheduler should be enabled + */ + @Override + public void setEnabled(boolean enabled) + { + this.isEnabled = enabled; + } + + /** + * Refreshes dictionary from system table and updates the cache. + * This method is called periodically by the scheduled refresh task. 
+ */ + private void refreshDictionaryFromSystemTable() + { + try + { + if (!isEnabled) + { + return; + } + + CompressionDictionary dictionary = SystemDistributedKeyspace.retrieveLatestCompressionDictionary(keyspaceName, tableName); + cache.setCurrentIfNewer(dictionary); + } + catch (Exception e) + { + logger.warn("Failed to refresh compression dictionary for {}.{}", + keyspaceName, tableName, e); + } + } + + @Override + public void close() + { + if (scheduledRefreshTask != null) + { + scheduledRefreshTask.cancel(false); + scheduledRefreshTask = null; + } + + if (scheduledManualTrainingTask != null) + { + scheduledManualTrainingTask.cancel(false); + scheduledManualTrainingTask = null; + } + } + + private class ManualTrainingTask implements Runnable + { + private final long deadlineMillis; + private final ICompressionDictionaryTrainer trainer; + private boolean isTraining = false; + + private ManualTrainingTask(long deadlineMillis, ICompressionDictionaryTrainer trainer) + { + this.deadlineMillis = deadlineMillis; + this.trainer = trainer; + } + + @Override + public void run() + { + if (trainer.getTrainingStatus() == TrainingStatus.NOT_STARTED) + { + logger.warn("Trainer is not started. Stop training dictionary for table {}.{}", keyspaceName, tableName); + cancelManualTraining(); + return; + } + + long now = System.currentTimeMillis(); + // Force training if there are not enough samples, but we have hit the max sampling duration + boolean reachedDeadline = now >= deadlineMillis; + if (!isTraining && (trainer.isReady() || reachedDeadline)) + { + // Set isTraining to only enter the branch once + isTraining = true; + trainer.trainDictionaryAsync(reachedDeadline) + .whenComplete((dictionary, throwable) -> { + cancelManualTraining(); + if (throwable != null) + { + logger.error("Manual dictionary training failed for {}.{}", keyspaceName, tableName, throwable); + } + else + { + logger.info("Manual dictionary training completed for {}.{}", keyspaceName, tableName); + } + }); + } + } + } + + @VisibleForTesting + ScheduledFuture scheduledManualTrainingTask() + { + return scheduledManualTrainingTask; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java new file mode 100644 index 000000000000..e8a28aea61e4 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfig.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import com.google.common.base.Preconditions; + +/** + * Configuration for dictionary training parameters. 
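The manual-training flow in CompressionDictionaryScheduler above polls once per second and forces training when the sampling deadline passes, cancelling its own scheduled task when finished. A JDK-only sketch of that schedule, poll, then self-cancel pattern follows; the real code runs on Cassandra's ScheduledExecutors and kicks off trainer.trainDictionaryAsync() instead of printing:

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Sketch of the ManualTrainingTask polling pattern: check every second, and once
// the deadline is reached, do the final work and cancel the periodic task itself.
public final class DeadlinePollDemo
{
    public static void main(String[] args) throws InterruptedException
    {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        long deadlineMillis = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(3);
        AtomicReference<ScheduledFuture<?>> handle = new AtomicReference<>();

        Runnable poll = () -> {
            boolean reachedDeadline = System.currentTimeMillis() >= deadlineMillis;
            System.out.println("polling... reachedDeadline=" + reachedDeadline);
            if (reachedDeadline)
            {
                // Equivalent to cancelManualTraining(): stop the periodic check once done
                ScheduledFuture<?> f = handle.get();
                if (f != null)
                    f.cancel(false);
            }
        };

        handle.set(executor.scheduleWithFixedDelay(poll, 1, 1, TimeUnit.SECONDS));
        Thread.sleep(5000);
        executor.shutdown();
    }
}
```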
+ */ +public class CompressionDictionaryTrainingConfig +{ + public final int maxDictionarySize; + public final int maxTotalSampleSize; + public final int acceptableTotalSampleSize; + public final int samplingRate; + + private CompressionDictionaryTrainingConfig(Builder builder) + { + this.maxDictionarySize = builder.maxDictionarySize; + this.maxTotalSampleSize = builder.maxTotalSampleSize; + this.acceptableTotalSampleSize = builder.maxTotalSampleSize / 10 * 8; + this.samplingRate = builder.samplingRate; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private int maxDictionarySize = 65536; // 64KB default + private int maxTotalSampleSize = 10 * 1024 * 1024; // 10MB total + private int samplingRate = 100; // Sampling 1% + + public Builder maxDictionarySize(int size) + { + this.maxDictionarySize = size; + return this; + } + + public Builder maxTotalSampleSize(int size) + { + this.maxTotalSampleSize = size; + return this; + } + + public Builder samplingRate(int samplingRate) + { + this.samplingRate = samplingRate; + return this; + } + + public CompressionDictionaryTrainingConfig build() + { + Preconditions.checkArgument(maxDictionarySize > 0, "maxDictionarySize must be positive"); + Preconditions.checkArgument(maxTotalSampleSize > 0, "maxTotalSampleSize must be positive"); + Preconditions.checkArgument(samplingRate > 0, "samplingRate must be positive"); + return new CompressionDictionaryTrainingConfig(this); + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java new file mode 100644 index 000000000000..a52afdd9aea0 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateMessage.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
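A usage sketch for the CompressionDictionaryTrainingConfig builder defined above, mirroring how CompressionDictionaryManager.createTrainingConfig() assembles it from DatabaseDescriptor values; the literals are just the documented defaults, and acceptableTotalSampleSize is derived as 80% of maxTotalSampleSize:

```java
import org.apache.cassandra.db.compression.CompressionDictionaryTrainingConfig;

// Usage example for the builder shown above; values are the documented defaults.
public final class TrainingConfigExample
{
    public static void main(String[] args)
    {
        CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig
            .builder()
            .maxDictionarySize(65536)        // 64 KiB dictionary cap
            .maxTotalSampleSize(10485760)    // at most 10 MiB of collected samples
            .samplingRate(100)               // ~1% of writes sampled
            .build();

        // acceptableTotalSampleSize = maxTotalSampleSize / 10 * 8
        System.out.println(config.acceptableTotalSampleSize); // 8388608
    }
}
```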
+ */ + +package org.apache.cassandra.db.compression; + +import java.io.IOException; + +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.io.IVersionedSerializer; +import org.apache.cassandra.io.util.DataInputPlus; +import org.apache.cassandra.io.util.DataOutputPlus; +import org.apache.cassandra.schema.TableId; + +public class CompressionDictionaryUpdateMessage +{ + public static final IVersionedSerializer serializer = new DictionaryUpdateMessageSerializer(); + + public final TableId tableId; + public final DictId dictionaryId; + + public CompressionDictionaryUpdateMessage(TableId tableId, DictId dictionaryId) + { + this.tableId = tableId; + this.dictionaryId = dictionaryId; + } + + public static class DictionaryUpdateMessageSerializer implements IVersionedSerializer + { + @Override + public void serialize(CompressionDictionaryUpdateMessage message, DataOutputPlus out, int version) throws IOException + { + TableId.serializer.serialize(message.tableId, out, version); + out.writeByte(message.dictionaryId.kind.ordinal()); + out.writeLong(message.dictionaryId.id); + } + + @Override + public CompressionDictionaryUpdateMessage deserialize(DataInputPlus in, int version) throws IOException + { + TableId tableId = TableId.serializer.deserialize(in, version); + int kindOrdinal = in.readByte(); + long dictionaryId = in.readLong(); + DictId dictId = new DictId(CompressionDictionary.Kind.values()[kindOrdinal], dictionaryId); + return new CompressionDictionaryUpdateMessage(tableId, dictId); + } + + @Override + public long serializedSize(CompressionDictionaryUpdateMessage message, int version) + { + return TableId.serializer.serializedSize(message.tableId, version) + + 1 + // byte for kind ordinal + 8; // long for dictionaryId + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java new file mode 100644 index 000000000000..f595b173024d --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/CompressionDictionaryUpdateVerbHandler.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.net.IVerbHandler; +import org.apache.cassandra.net.Message; +import org.apache.cassandra.schema.Schema; + +public class CompressionDictionaryUpdateVerbHandler implements IVerbHandler +{ + private static final Logger logger = LoggerFactory.getLogger(CompressionDictionaryUpdateVerbHandler.class); + public static final CompressionDictionaryUpdateVerbHandler instance = new CompressionDictionaryUpdateVerbHandler(); + + private CompressionDictionaryUpdateVerbHandler() {} + + @Override + public void doVerb(Message message) + { + CompressionDictionaryUpdateMessage payload = message.payload; + + try + { + ColumnFamilyStore cfs = Schema.instance.getColumnFamilyStoreInstance(payload.tableId); + if (cfs == null) + { + logger.warn("Received dictionary update for unknown table with tableId {}", payload.tableId); + return; + } + + logger.debug("Received dictionary update notification for {}.{} with dictionaryId {}", + cfs.keyspace, cfs.name, payload.dictionaryId); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + manager.onNewDictionaryAvailable(payload.dictionaryId); + } + catch (Exception e) + { + logger.error("Failed to process dictionary update notification for tableId {}", + payload.tableId, e); + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java new file mode 100644 index 000000000000..bc60771c1476 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryCache.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import javax.annotation.Nullable; + +/** + * Interface for managing compression dictionary caching and current dictionary state. + *
+ * Implementations handle: + * - Local caching of compression dictionaries with automatic cleanup + * - Managing the current active dictionary for write operations + * - Thread-safe access to cached dictionaries + */ +public interface ICompressionDictionaryCache extends AutoCloseable +{ + /** + * Gets the current active compression dictionary. + * + * @return the current compression dictionary, or null if no dictionary is available + */ + @Nullable + CompressionDictionary getCurrent(); + + /** + * Retrieves a specific compression dictionary by its identifier. + * + * @param dictId the dictionary identifier to look up + * @return the compression dictionary with the given identifier, or null if not found in cache + */ + @Nullable + CompressionDictionary get(CompressionDictionary.DictId dictId); + + /** + * Stores a compression dictionary in the local cache. + * + * @param compressionDictionary the compression dictionary to cache + */ + void add(CompressionDictionary compressionDictionary); + + /** + * Set the provided dictionary as the current if it is newer. + * Also adds the dictionary to the cache. + * + * @param dictionary the dictionary to potentially set as current, may be null + */ + void setCurrentIfNewer(@Nullable CompressionDictionary dictionary); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java new file mode 100644 index 000000000000..4ed2f8f2ad07 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryEventHandler.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +public interface ICompressionDictionaryEventHandler +{ + /** + * Invoked when a new dictionary is trained + * @param dictionaryId dictionary id + */ + void onNewDictionaryTrained(CompressionDictionary.DictId dictionaryId); + + /** + * Invoked when {@link CompressionDictionaryUpdateMessage} is received indicating + * a dictionary is trained and local node should retrieve the specified dictionary + * @param dictionaryId dictionary id + */ + void onNewDictionaryAvailable(CompressionDictionary.DictId dictionaryId); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java new file mode 100644 index 000000000000..d6d8d0184d6a --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryScheduler.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +/** + * Interface for managing scheduled tasks for compression dictionary operations. + *
+ * Implementations handle: + * - Periodic refresh of dictionaries from system tables + * - Manual training task scheduling and monitoring + * - Cleanup of scheduled tasks + */ +public interface ICompressionDictionaryScheduler extends AutoCloseable +{ + /** + * Schedules the periodic dictionary refresh task if not already scheduled. + */ + void scheduleRefreshTask(); + + /** + * Schedules manual training with the specified options. + * + * @param options parsed and validated training options + * @param trainer the trainer to use + * @throws IllegalStateException if training is already in progress + */ + void scheduleManualTraining(ManualTrainingOptions options, ICompressionDictionaryTrainer trainer); + + /** + * Cancel the in-progress manual training + */ + void cancelManualTraining(); + + /** + * Sets the enabled state of the scheduler. When disabled, refresh tasks will not execute. + * + * @param enabled whether the scheduler should be enabled + */ + void setEnabled(boolean enabled); +} diff --git a/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java new file mode 100644 index 000000000000..fc24d5620714 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ICompressionDictionaryTrainer.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import org.apache.cassandra.concurrent.ScheduledExecutors; +import org.apache.cassandra.io.compress.ICompressor; +import org.apache.cassandra.io.compress.IDictionaryCompressor; +import org.apache.cassandra.schema.CompressionParams; + +/** + * Interface for training compression dictionaries from sample data. + *
+ * Implementations handle: + * - Sample collection and management + * - Dictionary training lifecycle + * - Asynchronous training execution + * - Training status tracking + */ +public interface ICompressionDictionaryTrainer extends AutoCloseable +{ + /** + * Starts the trainer for collecting samples. + * + * @param manualTraining true if this is manual training, false for automatic + * @return true if the trainer is started; otherwise false. The trainer is not started + * in any of these conditions: 1. the trainer is closed; 2. neither manual nor + * auto training was requested; 3. it failed to start + */ + boolean start(boolean manualTraining); + + /** + * @return true if the trainer is ready to take a new sample; otherwise, false + */ + boolean shouldSample(); + + /** + * Adds a sample to the training dataset. + * + * @param sample the sample data to add for training + */ + void addSample(ByteBuffer sample); + + /** + * Trains and produces a compression dictionary from collected samples synchronously. + * + * @param force force the dictionary training even if there are not enough samples; + * otherwise, dictionary training won't start if the trainer is not ready + * @return the trained compression dictionary + */ + CompressionDictionary trainDictionary(boolean force); + + /** + * Trains and produces a compression dictionary from collected samples asynchronously. + * + * @param force force the dictionary training even if there are not enough samples + * @return CompletableFuture that completes when training is done + */ + default CompletableFuture trainDictionaryAsync(boolean force) + { + return CompletableFuture.supplyAsync(() -> trainDictionary(force), ScheduledExecutors.nonPeriodicTasks); + } + + /** + * @return true if enough samples have been collected for training + */ + boolean isReady(); + + /** + * Clears all collected samples and resets trainer state. + */ + void reset(); + + /** + * @return the current training status + */ + TrainingStatus getTrainingStatus(); + + /** + * @return the compression algorithm kind this trainer supports + */ + CompressionDictionary.Kind kind(); + + /** + * Determines if this trainer is compatible with the given compression parameters. + * This method allows the trainer to decide whether it can continue operating + * with new compression parameters or if a new trainer instance is needed. + * + * @param newParams the new compression parameters to check compatibility against + * @return true if this trainer is compatible with the new parameters, false otherwise + */ + boolean isCompatibleWith(CompressionParams newParams); + + /** + * Sets the listener for dictionary training events. + * + * @param listener the listener to be notified when dictionaries are trained, null to remove listener + */ + void setDictionaryTrainedListener(Consumer listener); + + /** + * Updates the sampling rate for this trainer. + * + * @param newSamplingRate the new sampling rate. For example, 1 = sample every time (100%), + * 2 = expect to sample 1/2 of the data (50%), n = expect to sample 1/n of the data + */ + void updateSamplingRate(int newSamplingRate); + + /** + * Factory method to create appropriate trainer based on compression parameters.
+ * + * @param keyspaceName the keyspace name for logging + * @param tableName the table name for logging + * @param params the compression parameters + * @param config the training configuration + * @return a dictionary trainer for the specified compression algorithm + * @throws IllegalArgumentException if no dictionary trainer is available for the compression algorithm + */ + static ICompressionDictionaryTrainer create(String keyspaceName, + String tableName, + CompressionParams params, + CompressionDictionaryTrainingConfig config) + { + ICompressor compressor = params.getSstableCompressor(); + if (!(compressor instanceof IDictionaryCompressor)) + throw new IllegalArgumentException("Compressor does not support dictionary training: " + params.getSstableCompressor()); + + IDictionaryCompressor dictionaryCompressor = (IDictionaryCompressor) compressor; + if (CompressionDictionary.Kind.ACCEPTABLE_DICTIONARY_KINDS.contains(dictionaryCompressor.acceptableDictionaryKind())) + return dictionaryCompressor.acceptableDictionaryKind().getTrainer(keyspaceName, tableName, config, compressor); + + throw new IllegalArgumentException("No dictionary trainer available for: " + params.getSstableCompressor()); + } + + enum TrainingStatus + { + NOT_STARTED, + SAMPLING, + TRAINING, + COMPLETED, + FAILED; + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ManualTrainingOptions.java b/src/java/org/apache/cassandra/db/compression/ManualTrainingOptions.java new file mode 100644 index 000000000000..f783b70810d5 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ManualTrainingOptions.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.Map; + +/** + * Configuration options for manual compression dictionary training. + * This class encapsulates the parsed and validated parameters needed for training. + */ +public class ManualTrainingOptions +{ + private final int maxSamplingDurationSeconds; + + public ManualTrainingOptions(int maxSamplingDurationSeconds) + { + if (maxSamplingDurationSeconds <= 0) + { + throw new IllegalArgumentException("maxSamplingDurationSeconds must be positive, got: " + maxSamplingDurationSeconds); + } + this.maxSamplingDurationSeconds = maxSamplingDurationSeconds; + } + + /** + * Parse options from a string map, typically from JMX/MBean calls. 
+ * + * @param options the string map containing training options + * @return parsed and validated ManualTrainingOptions + * @throws IllegalArgumentException if required parameters are missing or invalid + */ + public static ManualTrainingOptions fromStringMap(Map options) + { + if (options == null || !options.containsKey("maxSamplingDurationSeconds")) + { + throw new IllegalArgumentException("maxSamplingDurationSeconds parameter is required for manual dictionary training"); + } + + String durationStr = options.get("maxSamplingDurationSeconds"); + int maxSamplingDurationSeconds; + try + { + maxSamplingDurationSeconds = Integer.parseInt(durationStr); + } + catch (NumberFormatException e) + { + throw new IllegalArgumentException("Invalid maxSamplingDurationSeconds value: " + durationStr, e); + } + + return new ManualTrainingOptions(maxSamplingDurationSeconds); + } + + public int getMaxSamplingDurationSeconds() + { + return maxSamplingDurationSeconds; + } + + @Override + public String toString() + { + return "ManualTrainingOptions{" + + "maxSamplingDurationSeconds=" + maxSamplingDurationSeconds + + '}'; + } +} \ No newline at end of file diff --git a/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java b/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java new file mode 100644 index 000000000000..de141858e852 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ZstdCompressionDictionary.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.ZstdDictCompress; +import com.github.luben.zstd.ZstdDictDecompress; +import org.apache.cassandra.io.compress.ZstdCompressorBase; +import org.apache.cassandra.utils.concurrent.Ref; +import org.apache.cassandra.utils.concurrent.RefCounted; +import org.apache.cassandra.utils.concurrent.SelfRefCounted; + +public class ZstdCompressionDictionary implements CompressionDictionary, SelfRefCounted +{ + private static final Logger logger = LoggerFactory.getLogger(ZstdCompressionDictionary.class); + + private final DictId dictId; + private final byte[] rawDictionary; + // One ZstdDictDecompress and multiple ZstdDictCompress (per level) can be derived from the same raw dictionary content + private final ConcurrentHashMap zstdDictCompressPerLevel = new ConcurrentHashMap<>(); + private volatile ZstdDictDecompress dictDecompress; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Ref selfRef; + + public ZstdCompressionDictionary(DictId dictId, byte[] rawDictionary) + { + this.dictId = dictId; + this.rawDictionary = rawDictionary; + this.selfRef = new Ref<>(this, new Tidy(zstdDictCompressPerLevel, dictDecompress)); + } + + @Override + public DictId identifier() + { + return dictId; + } + + @Override + public Kind kind() + { + return Kind.ZSTD; + } + + @Override + public byte[] rawDictionary() + { + return rawDictionary; + } + + @Override + public boolean equals(Object o) + { + if (!(o instanceof ZstdCompressionDictionary)) return false; + ZstdCompressionDictionary that = (ZstdCompressionDictionary) o; + return Objects.equals(dictId, that.dictId); + } + + @Override + public int hashCode() + { + return dictId.hashCode(); + } + + /** + * Get a pre-processed compression table that is optimized for compression. + * It is derived/computed from dictionary bytes. + * The internal data structure is different from the tables for decompression. + * + * @param compressionLevel compression level to create the compression table + * @return ZstdDictCompress + */ + public ZstdDictCompress dictionaryForCompression(int compressionLevel) + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + + ZstdCompressorBase.validateCompressionLevel(compressionLevel); + + return zstdDictCompressPerLevel.computeIfAbsent(compressionLevel, level -> { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + return new ZstdDictCompress(rawDictionary, level); + }); + } + + /** + * Get a pre-processed decompression table that is optimized for decompression. + * It is derived/computed from dictionary bytes. + * The internal data structure is different from the tables for compression.
+ * + * @return ZstdDictDecompress + */ + public ZstdDictDecompress dictionaryForDecompression() + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + + ZstdDictDecompress result = dictDecompress; + if (result != null) + return result; + + synchronized (this) + { + if (closed.get()) + throw new IllegalStateException("Dictionary has been closed"); + + result = dictDecompress; + if (result == null) + { + result = new ZstdDictDecompress(rawDictionary); + dictDecompress = result; + } + return result; + } + } + + @Override + public Ref tryRef() + { + return selfRef.tryRef(); + } + + @Override + public Ref selfRef() + { + return selfRef; + } + + @Override + public Ref ref() + { + return selfRef.ref(); + } + + @Override + public void close() + { + if (closed.compareAndSet(false, true)) + { + selfRef.release(); + } + } + + private static class Tidy implements RefCounted.Tidy + { + private final ConcurrentHashMap zstdDictCompressPerLevel; + private volatile ZstdDictDecompress dictDecompress; + + Tidy(ConcurrentHashMap zstdDictCompressPerLevel, ZstdDictDecompress dictDecompress) + { + this.zstdDictCompressPerLevel = zstdDictCompressPerLevel; + this.dictDecompress = dictDecompress; + } + + @Override + public void tidy() + { + // Close all compression dictionaries + for (ZstdDictCompress compressDict : zstdDictCompressPerLevel.values()) + { + try + { + compressDict.close(); + } + catch (Exception e) + { + // Log but don't fail - continue closing other resources + logger.warn("Failed to close ZstdDictCompress", e); + } + } + zstdDictCompressPerLevel.clear(); + + // Close decompression dictionary + ZstdDictDecompress decompressDict = dictDecompress; + if (decompressDict != null) + { + try + { + decompressDict.close(); + } + catch (Exception e) + { + logger.warn("Failed to close ZstdDictDecompress", e); + } + dictDecompress = null; + } + } + + @Override + public String name() + { + return "ZstdCompressionDictionary"; + } + } +} diff --git a/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java b/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java new file mode 100644 index 000000000000..fd84a142aa29 --- /dev/null +++ b/src/java/org/apache/cassandra/db/compression/ZstdDictionaryTrainer.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdDictTrainer; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.io.compress.IDictionaryCompressor; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; +import org.apache.cassandra.schema.CompressionParams; + +/** + * Zstd implementation of dictionary trainer with lifecycle management. + */ +public class ZstdDictionaryTrainer implements ICompressionDictionaryTrainer +{ + private static final Logger logger = LoggerFactory.getLogger(ZstdDictionaryTrainer.class); + + private final String keyspaceName; + private final String tableName; + private final CompressionDictionaryTrainingConfig config; + private final AtomicLong totalSampleSize; + private final AtomicLong sampleCount; + private final int compressionLevel; // optimal if using the same level for training as when compressing. + + // Sampling rate can be updated during training + private volatile int samplingRate; + + // Minimum number of samples required by ZSTD library + private static final int MIN_SAMPLES_REQUIRED = 10; + + private volatile Consumer dictionaryTrainedListener; + // TODO: manage the samples in this class for auto-train (follow-up). The ZstdDictTrainer cannot be re-used for multiple training runs. 
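To make the sampling mechanics concrete: shouldSample() below draws a uniform integer in [0, samplingRate) and takes a sample only when it is 0, so a rate of n means roughly one of every n eligible chunks is offered to the trainer. A minimal sketch of the expected sample volume (illustrative numbers only, not values prescribed by this patch):

    int samplingRate = 100;                               // ~1% of eligible chunks are sampled
    long chunksOffered = 1_000_000;                       // chunks seen on the write path
    long expectedSamples = chunksOffered / samplingRate;  // ~10,000 samples collected
    boolean sampleThisChunk = java.util.concurrent.ThreadLocalRandom.current().nextInt(samplingRate) == 0;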
+ private ZstdDictTrainer zstdTrainer; + private volatile boolean closed = false; + private volatile TrainingStatus currentTrainingStatus; + + public ZstdDictionaryTrainer(String keyspaceName, String tableName, + CompressionDictionaryTrainingConfig config, + int compressionLevel) + { + this.keyspaceName = keyspaceName; + this.tableName = tableName; + this.config = config; + this.totalSampleSize = new AtomicLong(0); + this.sampleCount = new AtomicLong(0); + this.compressionLevel = compressionLevel; + this.samplingRate = config.samplingRate; + this.currentTrainingStatus = TrainingStatus.NOT_STARTED; + } + + @Override + public boolean shouldSample() + { + return zstdTrainer != null && ThreadLocalRandom.current().nextInt(samplingRate) == 0; + } + + @Override + public void addSample(ByteBuffer sample) + { + if (closed || sample == null || !sample.hasRemaining() || zstdTrainer == null) + return; + + byte[] sampleBytes = new byte[sample.remaining()]; + sample.duplicate().get(sampleBytes); + + if (zstdTrainer.addSample(sampleBytes)) + { + // Update the totalSampleSize and sampleCount if the sample is added + totalSampleSize.addAndGet(sampleBytes.length); + sampleCount.incrementAndGet(); + } + } + + @Override + public CompressionDictionary trainDictionary(boolean force) + { + boolean isReady = isReady(); + if (!force && !isReady) + { + currentTrainingStatus = TrainingStatus.FAILED; + throw new IllegalStateException("Trainer is not ready"); + } + + long currentSampleCount = sampleCount.get(); + if (currentSampleCount < MIN_SAMPLES_REQUIRED) // minimum samples should be required even if force training + { + currentTrainingStatus = TrainingStatus.FAILED; + String errorMsg = String.format("Insufficient samples for training: %d (minimum required: %d)", + currentSampleCount, MIN_SAMPLES_REQUIRED); + throw new IllegalStateException(errorMsg); + } + + currentTrainingStatus = TrainingStatus.TRAINING; + try + { + logger.debug("Training with sample count: {}, sample size: {}, isReady: {}", + currentSampleCount, totalSampleSize.get(), isReady); + byte[] dictBytes = zstdTrainer.trainSamples(); + long zstdDictId = Zstd.getDictIdFromDict(dictBytes); + DictId dictId = new DictId(Kind.ZSTD, DictId.makeDictId(System.currentTimeMillis(), zstdDictId)); + currentTrainingStatus = TrainingStatus.COMPLETED; + logger.debug("New dictionary is trained with {}", dictId); + CompressionDictionary dictionary = new ZstdCompressionDictionary(dictId, dictBytes); + notifyDictionaryTrainedListener(dictionary); + return dictionary; + } + catch (Exception e) + { + currentTrainingStatus = TrainingStatus.FAILED; + throw new RuntimeException("Failed to train Zstd dictionary", e); + } + } + + @Override + public boolean isReady() + { + return currentTrainingStatus != TrainingStatus.TRAINING + && !closed + && zstdTrainer != null + && totalSampleSize.get() >= config.acceptableTotalSampleSize + && sampleCount.get() > MIN_SAMPLES_REQUIRED; + } + + @Override + public TrainingStatus getTrainingStatus() + { + return currentTrainingStatus; + } + + @Override + public boolean start(boolean manualTraining) + { + if (closed || !(manualTraining || shouldAutoStartTraining())) + return false; + + try + { + // reset on starting; a new zstdTrainer instance is created during reset + reset(); + logger.info("Started dictionary training for {}.{}", keyspaceName, tableName); + currentTrainingStatus = TrainingStatus.SAMPLING; + return true; + } + catch (Exception e) + { + logger.warn("Failed to create ZstdDictTrainer for {}.{}", keyspaceName, tableName, e); + 
currentTrainingStatus = TrainingStatus.FAILED; + } + return false; + } + + /** + * Determines if training should auto-start based on configuration. + */ + private boolean shouldAutoStartTraining() + { + return DatabaseDescriptor.getCompressionDictionaryTrainingAutoTrainEnabled(); + } + + @Override + public void reset() + { + if (closed) + { + return; + } + + currentTrainingStatus = TrainingStatus.NOT_STARTED; + synchronized (this) + { + totalSampleSize.set(0); + sampleCount.set(0); + zstdTrainer = new ZstdDictTrainer(config.maxTotalSampleSize, config.maxDictionarySize, compressionLevel); + } + } + + @Override + public Kind kind() + { + return Kind.ZSTD; + } + + @Override + public void setDictionaryTrainedListener(Consumer listener) + { + this.dictionaryTrainedListener = listener; + } + + @Override + public void updateSamplingRate(int newSamplingRate) + { + if (newSamplingRate <= 0) + { + throw new IllegalArgumentException("Sampling rate must be positive, got: " + newSamplingRate); + } + this.samplingRate = newSamplingRate; + logger.debug("Updated sampling rate to {} for {}.{}", newSamplingRate, keyspaceName, tableName); + } + + /** + * Notifies the registered listener that a dictionary has been trained. + * + * @param dictionary the newly trained dictionary + */ + private void notifyDictionaryTrainedListener(CompressionDictionary dictionary) + { + Consumer listener = this.dictionaryTrainedListener; + if (listener != null) + { + try + { + listener.accept(dictionary); + } + catch (Exception e) + { + logger.warn("Error notifying dictionary trained listener for {}.{}", keyspaceName, tableName, e); + } + } + } + + @Override + public boolean isCompatibleWith(CompressionParams newParams) + { + if (!newParams.isDictionaryCompressionEnabled()) + { + return false; + } + + IDictionaryCompressor newCompressor = (IDictionaryCompressor) newParams.getSstableCompressor(); + + // Check if the compressor type is compatible with this trainer + if (newCompressor.acceptableDictionaryKind() != Kind.ZSTD) + { + return false; + } + + ZstdDictionaryCompressor zstdDictionaryCompressor = (ZstdDictionaryCompressor) newCompressor; + // For Zstd compressors, check if compression level matches + return this.compressionLevel == zstdDictionaryCompressor.compressionLevel(); + } + + @Override + public void close() + { + if (closed) + return; + + closed = true; + currentTrainingStatus = TrainingStatus.NOT_STARTED; + + synchronized (this) + { + // Permanent shutdown: clear all state and prevent restart + totalSampleSize.set(0); + sampleCount.set(0); + zstdTrainer = null; + } + + logger.info("Permanently closed dictionary trainer for {}.{}", keyspaceName, tableName); + } + + @VisibleForTesting + long getSampleCount() + { + return sampleCount.get(); + } + + @VisibleForTesting + Object trainer() + { + return zstdTrainer; + } +} diff --git a/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java b/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java index ea6448f8182c..1f4cd517b021 100644 --- a/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java +++ b/src/java/org/apache/cassandra/io/compress/CompressedSequentialWriter.java @@ -24,7 +24,10 @@ import java.nio.channels.Channels; import java.util.Optional; import java.util.zip.CRC32; +import javax.annotation.Nullable; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.FSReadError; import 
org.apache.cassandra.io.FSWriteError; import org.apache.cassandra.io.sstable.CorruptSSTableException; @@ -61,11 +64,24 @@ public class CompressedSequentialWriter extends SequentialWriter private long uncompressedSize = 0, compressedSize = 0; private final MetadataCollector sstableMetadataCollector; + private final CompressionDictionaryManager compressionDictionaryManager; private final ByteBuffer crcCheckBuffer = ByteBuffer.allocate(4); private final Optional digestFile; private final int maxCompressedLength; + private final boolean isDictionaryEnabled; + + public CompressedSequentialWriter(File file, + File offsetsFile, + File digestFile, + SequentialWriterOption option, + CompressionParams parameters, + MetadataCollector sstableMetadataCollector) + { + this(file, offsetsFile, digestFile, option, parameters, sstableMetadataCollector, null); + } + /** * Create CompressedSequentialWriter without digest file. @@ -74,15 +90,17 @@ public class CompressedSequentialWriter extends SequentialWriter * @param offsetsFile File to write compression metadata * @param digestFile File to write digest * @param option Write option (buffer size and type will be set the same as compression params) - * @param parameters Compression mparameters + * @param parameters Compression parameters * @param sstableMetadataCollector Metadata collector + * @param compressionDictionaryManager manages compression dictionary; null if absent */ public CompressedSequentialWriter(File file, File offsetsFile, File digestFile, SequentialWriterOption option, CompressionParams parameters, - MetadataCollector sstableMetadataCollector) + MetadataCollector sstableMetadataCollector, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { super(file, SequentialWriterOption.newBuilder() .bufferSize(option.bufferSize()) @@ -91,7 +109,7 @@ public CompressedSequentialWriter(File file, .bufferType(parameters.getSstableCompressor().preferredBufferType()) .finishOnClose(option.finishOnClose()) .build()); - this.compressor = parameters.getSstableCompressor(); + ICompressor compressor = parameters.getSstableCompressor(); this.digestFile = Optional.ofNullable(digestFile); // buffer for compression should be the same size as buffer itself @@ -99,8 +117,28 @@ public CompressedSequentialWriter(File file, maxCompressedLength = parameters.maxCompressedLength(); + // Note that we cannot rely on the compressor type to tell whether dictionary compression is enabled. + // Because the `CompressionParams` for this method is updated at the callsite, `DataComponent.buildWriter`. + // See CASSANDRA-15379 for details regarding the optimization. + // Meanwhile, as long as dictionary-based compression is enabled, we want to collect samples. + this.isDictionaryEnabled = compressionDictionaryManager != null && compressionDictionaryManager.isEnabled(); + + CompressionDictionary compressionDictionary = compressionDictionaryManager == null ? null : compressionDictionaryManager.getCurrent(); + if (compressionDictionary != null && compressor instanceof IDictionaryCompressor) + { + compressor = ((IDictionaryCompressor) compressor).getOrCopyWithDictionary(compressionDictionary); + } + else + { + // It is likely on the sstable flushing path and LZ4 compressor or something else is picked. + // In this case, we disable the compression dictionary, i.e. do not attach the dictionary + // bytes to the CompressionInfo component. 
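The dictionary-resolution rule described in the comment above can be summarised as the following sketch (the helper name is hypothetical; the actual logic stays inline in this constructor):

    static ICompressor resolveForWrite(ICompressor configured, @Nullable CompressionDictionary current)
    {
        // Attach a dictionary only when the effective compressor can actually consume one
        if (current != null && configured instanceof IDictionaryCompressor)
            return ((IDictionaryCompressor) configured).getOrCopyWithDictionary(current);
        return configured; // e.g. LZ4 on the flush path: compress without attaching a dictionary
    }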
+ compressionDictionary = null; + } + this.compressor = compressor; + this.compressionDictionaryManager = compressionDictionaryManager; /* Index File (-CompressionInfo.db component) and it's header */ - metadataWriter = CompressionMetadata.Writer.open(parameters, offsetsFile); + metadataWriter = CompressionMetadata.Writer.open(parameters, offsetsFile, compressionDictionary); this.sstableMetadataCollector = sstableMetadataCollector; crcMetadata = new ChecksumWriter(new DataOutputStream(Channels.newOutputStream(channel))); @@ -145,6 +183,13 @@ protected void flushData() { // compressing data with buffer re-use buffer.flip(); + + // Collect sample for dictionary training before compression + if (isDictionaryEnabled) + { + compressionDictionaryManager.addSample(buffer.duplicate()); + } + compressed.clear(); compressor.compress(buffer, compressed); } @@ -440,4 +485,4 @@ public CompressedFileWriterMark(long chunkOffset, long uncDataOffset, int validB this.nextChunkIndex = nextChunkIndex; } } -} \ No newline at end of file +} diff --git a/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java b/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java index d5f5f05655e9..cd238cd0f5c8 100644 --- a/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java +++ b/src/java/org/apache/cassandra/io/compress/CompressionMetadata.java @@ -27,11 +27,14 @@ import java.util.Map; import java.util.SortedSet; import java.util.TreeSet; +import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; import com.google.common.primitives.Longs; import org.apache.cassandra.db.TypeSizes; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.exceptions.ConfigurationException; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.FSWriteError; @@ -65,13 +68,28 @@ public class CompressionMetadata extends WrappedSharedCloseable private final long chunkOffsetsSize; public final File chunksIndexFile; public final CompressionParams parameters; + @Nullable // null when no dictionary + private final CompressionDictionary compressionDictionary; + private volatile ICompressor resolvedCompressor; @VisibleForTesting - public static CompressionMetadata open(File chunksIndexFile, long compressedLength, boolean hasMaxCompressedSize) + public static CompressionMetadata open(File chunksIndexFile, + long compressedLength, + boolean hasMaxCompressedSize) + { + return open(chunksIndexFile, compressedLength, hasMaxCompressedSize, null); + } + + @VisibleForTesting + public static CompressionMetadata open(File chunksIndexFile, + long compressedLength, + boolean hasMaxCompressedSize, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { CompressionParams parameters; long dataLength; Memory chunkOffsets; + CompressionDictionary compressionDictionary; try (FileInputStreamPlus stream = chunksIndexFile.newInputStream()) { @@ -99,6 +117,7 @@ public static CompressionMetadata open(File chunksIndexFile, long compressedLeng dataLength = stream.readLong(); chunkOffsets = readChunkOffsets(stream); + compressionDictionary = CompressionDictionary.deserialize(stream, compressionDictionaryManager); } catch (FileNotFoundException | NoSuchFileException e) { @@ -109,7 +128,9 @@ public static CompressionMetadata open(File chunksIndexFile, long compressedLeng throw new CorruptSSTableException(e, chunksIndexFile); } - return new 
CompressionMetadata(chunksIndexFile, parameters, chunkOffsets, chunkOffsets.size(), dataLength, compressedLength); + return new CompressionMetadata(chunksIndexFile, parameters, + chunkOffsets, chunkOffsets.size(), dataLength, + compressedLength, compressionDictionary); } // do not call this constructor directly, unless used in testing @@ -119,7 +140,8 @@ public CompressionMetadata(File chunksIndexFile, Memory chunkOffsets, long chunkOffsetsSize, long dataLength, - long compressedFileLength) + long compressedFileLength, + CompressionDictionary compressionDictionary) { super(chunkOffsets); this.chunksIndexFile = chunksIndexFile; @@ -128,6 +150,7 @@ public CompressionMetadata(File chunksIndexFile, this.compressedFileLength = compressedFileLength; this.chunkOffsets = chunkOffsets; this.chunkOffsetsSize = chunkOffsetsSize; + this.compressionDictionary = compressionDictionary; } private CompressionMetadata(CompressionMetadata copy) @@ -139,11 +162,47 @@ private CompressionMetadata(CompressionMetadata copy) this.compressedFileLength = copy.compressedFileLength; this.chunkOffsets = copy.chunkOffsets; this.chunkOffsetsSize = copy.chunkOffsetsSize; + this.compressionDictionary = copy.compressionDictionary; } public ICompressor compressor() { - return parameters.getSstableCompressor(); + // classic double-checked locking to call resolveCompressor method just once per CompressionMetadata object + ICompressor result = resolvedCompressor; + if (result != null) + return result; + + synchronized (this) + { + result = resolvedCompressor; + if (result == null) + { + result = resolveCompressor(parameters.getSstableCompressor(), compressionDictionary); + resolvedCompressor = result; + } + return result; + } + } + + static ICompressor resolveCompressor(ICompressor compressor, CompressionDictionary dictionary) + { + if (dictionary == null) + return compressor; + + // When the attached dictionary can be consumed by the current dictionary compressor + if (compressor instanceof IDictionaryCompressor) + { + IDictionaryCompressor dictionaryCompressor = (IDictionaryCompressor) compressor; + if (dictionaryCompressor.canConsumeDictionary(dictionary)) + return dictionaryCompressor.getOrCopyWithDictionary(dictionary); + } + + // When the current compressor is not compatible with the dictionary. It could happen in the read path when: + // 1. The current compressor is not a dictionary compressor, but there is dictionary attached + // 2. The current dictionary compressor is a different type, e.g. 
table schema is changed + // In those cases, we should get the compatible dictionary compressor based on the dictionary + + return dictionary.getCompressor(); } public int chunkLength() @@ -349,16 +408,21 @@ public static class Writer extends Transactional.AbstractTransactional implement // provided by user when setDescriptor private long dataLength, chunkCount; + @Nullable + private CompressionDictionary compressionDictionary; - private Writer(CompressionParams parameters, File file) + private Writer(CompressionParams parameters, File file, CompressionDictionary compressionDictionary) { this.parameters = parameters; this.file = file; + this.compressionDictionary = compressionDictionary; } - public static Writer open(CompressionParams parameters, File file) + public static Writer open(CompressionParams parameters, + File file, + CompressionDictionary compressionDictionary) { - return new Writer(parameters, file); + return new Writer(parameters, file, compressionDictionary); } public void addOffset(long offset) @@ -397,6 +461,21 @@ private void writeHeader(DataOutput out, long dataLength, int chunks) } } + private void writeCompressionDictionary(DataOutput out) + { + if (compressionDictionary == null) + return; + + try + { + compressionDictionary.serialize(out); + } + catch (IOException e) + { + throw new FSWriteError(e, file); + } + } + // we've written everything; wire up some final metadata state public Writer finalizeLength(long dataLength, int chunkCount) { @@ -426,6 +505,7 @@ public void doPrepare() for (int i = 0; i < count; i++) out.writeLong(offsets.getLong(i * 8L)); + writeCompressionDictionary(out); out.flush(); out.sync(); } @@ -453,7 +533,9 @@ public CompressionMetadata open(long dataLength, long compressedLength) if (tCount < this.count) compressedLength = tOffsets.getLong(tCount * 8L); - return new CompressionMetadata(file, parameters, tOffsets, tCount * 8L, dataLength, compressedLength); + return new CompressionMetadata(file, parameters, + tOffsets, tCount * 8L, dataLength, + compressedLength, compressionDictionary); } /** diff --git a/src/java/org/apache/cassandra/io/compress/ICompressor.java b/src/java/org/apache/cassandra/io/compress/ICompressor.java index fd6a104431b3..950ae03e3de4 100644 --- a/src/java/org/apache/cassandra/io/compress/ICompressor.java +++ b/src/java/org/apache/cassandra/io/compress/ICompressor.java @@ -37,6 +37,11 @@ enum Uses { FAST_COMPRESSION } + /** + * Get the maximum compressed size in the worst case scenario + * @param chunkLength input data (chunk) size + * @return compressed size upper bound in the worse case + */ public int initialCompressedBufferLength(int chunkLength); public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws IOException; diff --git a/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java b/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java new file mode 100644 index 000000000000..1d919f18b5cd --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/IDictionaryCompressor.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import org.apache.cassandra.db.compression.CompressionDictionary; + +/** + * Interface for compressors that support dictionary-based compression. + *
+ * Dictionary compressors can use pre-trained compression dictionaries to achieve + * better compression ratios, especially for small data chunks that are similar + * to the training data used to create the dictionary. + * + * @param the specific type of compression dictionary this compressor supports + */ +public interface IDictionaryCompressor extends ICompressor +{ + /** + * Returns a compressor instance configured with the specified compression dictionary. + *
+ * This method may return the same instance if it already uses the given dictionary, + * or create a new instance configured with the dictionary. The implementation should + * be efficient and avoid unnecessary object creation when possible. + * + * @param compressionDictionary the dictionary to use for compression/decompression + * @return a compressor instance that will use the specified dictionary + */ + ICompressor getOrCopyWithDictionary(T compressionDictionary); + + /** + * Returns the kind of compression dictionary that this compressor can accept. + *
+ * This is used to validate dictionary compatibility before attempting to use + * a dictionary with this compressor. Only dictionaries of the returned kind + * should be passed to {@link #getOrCopyWithDictionary(CompressionDictionary)}. + * + * @return the compression dictionary kind supported by this compressor + */ + CompressionDictionary.Kind acceptableDictionaryKind(); + + /** + * Checks whether this compressor can use the given compression dictionary. + *
+ * The default implementation compares the dictionary's kind with the kind + * returned by {@link #acceptableDictionaryKind()}. Compressor implementations + * may override this method to provide more sophisticated compatibility checks. + * + * @param dictionary the compression dictionary to check for compatibility + * @return true if this compressor can use the dictionary, false otherwise + */ + default boolean canConsumeDictionary(CompressionDictionary dictionary) + { + return dictionary.kind() == acceptableDictionaryKind(); + } +} diff --git a/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java b/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java index c86db26c8621..9327de4c139c 100644 --- a/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java +++ b/src/java/org/apache/cassandra/io/compress/ZstdCompressor.java @@ -18,72 +18,31 @@ package org.apache.cassandra.io.compress; -import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Collections; -import java.util.HashSet; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableSet; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.github.luben.zstd.Zstd; - /** * ZSTD Compressor */ -public class ZstdCompressor implements ICompressor +public class ZstdCompressor extends ZstdCompressorBase implements ICompressor { - private static final Logger logger = LoggerFactory.getLogger(ZstdCompressor.class); - - // These might change with the version of Zstd we're using - public static final int FAST_COMPRESSION_LEVEL = Zstd.minCompressionLevel(); - public static final int BEST_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); - - // Compressor Defaults - public static final int DEFAULT_COMPRESSION_LEVEL = 3; - private static final boolean ENABLE_CHECKSUM_FLAG = true; - - @VisibleForTesting - public static final String COMPRESSION_LEVEL_OPTION_NAME = "compression_level"; - private static final ConcurrentHashMap instances = new ConcurrentHashMap<>(); - private final int compressionLevel; - private final Set recommendedUses; - /** * Create a Zstd compressor with the given options + * Invoked by {@link org.apache.cassandra.schema.CompressionParams#createCompressor} via reflection * - * @param options - * @return + * @param options compression options + * @return ZstdCompressor */ public static ZstdCompressor create(Map options) { int level = getOrDefaultCompressionLevel(options); - - if (!isValid(level)) - throw new IllegalArgumentException(String.format("%s=%d is invalid", COMPRESSION_LEVEL_OPTION_NAME, level)); - + validateCompressionLevel(level); return getOrCreate(level); } - /** - * Private constructor - * - * @param compressionLevel - */ - private ZstdCompressor(int compressionLevel) - { - this.compressionLevel = compressionLevel; - this.recommendedUses = ImmutableSet.of(Uses.GENERAL); - logger.trace("Creating Zstd Compressor with compression level={}", compressionLevel); - } - /** * Get a cached instance or return a new one * @@ -92,157 +51,16 @@ private ZstdCompressor(int compressionLevel) */ public static ZstdCompressor getOrCreate(int level) { - return instances.computeIfAbsent(level, l -> new ZstdCompressor(level)); - } - - /** - * Get initial compressed buffer length - * - * @param chunkLength - * @return - */ - @Override - public int initialCompressedBufferLength(int chunkLength) - { - return (int) Zstd.compressBound(chunkLength); - } - - /** - 
* Decompress data using arrays - * - * @param input - * @param inputOffset - * @param inputLength - * @param output - * @param outputOffset - * @return - * @throws IOException - */ - @Override - public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) - throws IOException - { - long dsz = Zstd.decompressByteArray(output, outputOffset, output.length - outputOffset, - input, inputOffset, inputLength); - - if (Zstd.isError(dsz)) - throw new IOException(String.format("Decompression failed due to %s", Zstd.getErrorName(dsz))); - - return (int) dsz; - } - - /** - * Decompress data via ByteBuffers - * - * @param input - * @param output - * @throws IOException - */ - @Override - public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException - { - try - { - Zstd.decompress(output, input); - } catch (Exception e) - { - throw new IOException("Decompression failed", e); - } - } - - /** - * Compress using ByteBuffers - * - * @param input - * @param output - * @throws IOException - */ - @Override - public void compress(ByteBuffer input, ByteBuffer output) throws IOException - { - try - { - Zstd.compress(output, input, compressionLevel, ENABLE_CHECKSUM_FLAG); - } catch (Exception e) - { - throw new IOException("Compression failed", e); - } - } - - /** - * Check if the given compression level is valid. This can be a negative value as well. - * - * @param level - * @return - */ - private static boolean isValid(int level) - { - return (level >= FAST_COMPRESSION_LEVEL && level <= BEST_COMPRESSION_LEVEL); + return instances.computeIfAbsent(level, ZstdCompressor::new); } /** - * Parse the compression options - * - * @param options - * @return - */ - private static int getOrDefaultCompressionLevel(Map options) - { - if (options == null) - return DEFAULT_COMPRESSION_LEVEL; - - String val = options.get(COMPRESSION_LEVEL_OPTION_NAME); - - if (val == null) - return DEFAULT_COMPRESSION_LEVEL; - - return Integer.valueOf(val); - } - - /** - * Return the preferred BufferType - * - * @return - */ - @Override - public BufferType preferredBufferType() - { - return BufferType.OFF_HEAP; - } - - /** - * Check whether the given BufferType is supported - * - * @param bufferType - * @return - */ - @Override - public boolean supports(BufferType bufferType) - { - return bufferType == BufferType.OFF_HEAP; - } - - /** - * Lists the supported options by this compressor + * Private constructor * - * @return + * @param compressionLevel */ - @Override - public Set supportedOptions() - { - return new HashSet<>(Collections.singletonList(COMPRESSION_LEVEL_OPTION_NAME)); - } - - - @VisibleForTesting - public int getCompressionLevel() - { - return compressionLevel; - } - - @Override - public Set recommendedUses() + private ZstdCompressor(int compressionLevel) { - return recommendedUses; + super(compressionLevel, Collections.singleton(COMPRESSION_LEVEL_OPTION_NAME)); } } diff --git a/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java b/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java new file mode 100644 index 000000000000..bbe278cbf43c --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/ZstdCompressorBase.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.luben.zstd.Zstd; + +public abstract class ZstdCompressorBase implements ICompressor +{ + // These might change with the version of Zstd we're using + public static final int FAST_COMPRESSION_LEVEL = Zstd.minCompressionLevel(); + public static final int BEST_COMPRESSION_LEVEL = Zstd.maxCompressionLevel(); + + // Compressor Defaults + public static final int DEFAULT_COMPRESSION_LEVEL = 3; + public static final boolean ENABLE_CHECKSUM_FLAG = true; + + // Compressor option names + public static final String COMPRESSION_LEVEL_OPTION_NAME = "compression_level"; + + protected final Logger logger = LoggerFactory.getLogger(getClass()); + + private final int compressionLevel; + private final Set recommendedUses; + private final Set supportedOptions; + + protected ZstdCompressorBase(int compressionLevel, Set supportedOptions) + { + this.compressionLevel = compressionLevel; + this.supportedOptions = Collections.unmodifiableSet(supportedOptions); + this.recommendedUses = Set.of(ICompressor.Uses.GENERAL); + logger.trace("Creating Zstd Compressor with compression level={}", compressionLevel); + } + + @Override + public int initialCompressedBufferLength(int chunkLength) + { + return (int) Zstd.compressBound(chunkLength); + } + + @Override + public BufferType preferredBufferType() + { + return BufferType.OFF_HEAP; + } + + @Override + public boolean supports(BufferType bufferType) + { + return bufferType == BufferType.OFF_HEAP; + } + + @Override + public Set recommendedUses() + { + return recommendedUses; + } + + @VisibleForTesting + public int compressionLevel() + { + return compressionLevel; + } + + @Override + public Set supportedOptions() + { + return supportedOptions; + } + + /** + * Decompress data using arrays + * + * @param input + * @param inputOffset + * @param inputLength + * @param output + * @param outputOffset + * @return + * @throws IOException + */ + @Override + public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) + throws IOException + { + long dsz; + try + { + dsz = Zstd.decompressByteArray(output, outputOffset, output.length - outputOffset, + input, inputOffset, inputLength); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + + if (Zstd.isError(dsz)) + throw new IOException("Decompression failed due to " + Zstd.getErrorName(dsz)); + + return (int) dsz; + } + + /** + * Decompress data via ByteBuffers + * + * @param input + * @param output + * @throws IOException + */ + @Override + public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException + { + try + { + Zstd.decompress(output, input); + } 
catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + } + + /** + * Compress using ByteBuffers + * + * @param input + * @param output + * @throws IOException + */ + @Override + public void compress(ByteBuffer input, ByteBuffer output) throws IOException + { + try + { + Zstd.compress(output, input, compressionLevel(), ENABLE_CHECKSUM_FLAG); + } catch (Exception e) + { + throw new IOException("Compression failed", e); + } + } + + /** + * Check if the given compression level is valid. This can be a negative value as well. + * + * @param level compression level + */ + public static void validateCompressionLevel(int level) + { + if (level < FAST_COMPRESSION_LEVEL || level > BEST_COMPRESSION_LEVEL) + { + throw new IllegalArgumentException(String.format("%s=%d is invalid", COMPRESSION_LEVEL_OPTION_NAME, level)); + } + } + + /** + * Get the supplied compression level; otherwise, use the default + * + * @param options compression options + * @return compression level + */ + public static int getOrDefaultCompressionLevel(Map options) + { + if (options == null) + return DEFAULT_COMPRESSION_LEVEL; + + String val = options.get(COMPRESSION_LEVEL_OPTION_NAME); + + if (val == null) + return DEFAULT_COMPRESSION_LEVEL; + + return Integer.parseInt(val); + } +} diff --git a/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java b/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java new file mode 100644 index 000000000000..d503f0d33261 --- /dev/null +++ b/src/java/org/apache/cassandra/io/compress/ZstdDictionaryCompressor.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.io.compress; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import com.github.luben.zstd.Zstd; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.utils.concurrent.Ref; + +import javax.annotation.Nullable; + +public class ZstdDictionaryCompressor extends ZstdCompressorBase implements IDictionaryCompressor +{ + private static final ConcurrentHashMap instancesPerLevel = new ConcurrentHashMap<>(); + private static final Cache instancePerDict = + Caffeine.newBuilder() + .maximumSize(DatabaseDescriptor.getCompressionDictionaryCacheSize()) + .expireAfterAccess(Duration.ofSeconds(DatabaseDescriptor.getCompressionDictionaryCacheExpireSeconds())) + .removalListener((ZstdCompressionDictionary dictionary, + ZstdDictionaryCompressor compressor, + RemovalCause cause) -> { + // Release dictionary reference when compressor is evicted from cache + if (compressor != null && compressor.dictionaryRef != null) + { + compressor.dictionaryRef.release(); + } + }) + .build(); + + // The dictionary and its ref are null when they are absent. + // In this case, the compressor falls back to behaving the same as ZstdCompressor + @Nullable + private final ZstdCompressionDictionary dictionary; + @Nullable + private final Ref dictionaryRef; + + /** + * Create a ZstdDictionaryCompressor with the given options + * Invoked by {@link org.apache.cassandra.schema.CompressionParams#createCompressor} via reflection + * + * @param options compression options + * @return ZstdDictionaryCompressor + */ + public static ZstdDictionaryCompressor create(Map options) + { + int level = getOrDefaultCompressionLevel(options); + validateCompressionLevel(level); + return getOrCreate(level, null); + } + + // Factory method used to create the compressor for reading the sstable; the compression level is not relevant + public static ZstdDictionaryCompressor create(ZstdCompressionDictionary dictionary) + { + return getOrCreate(DEFAULT_COMPRESSION_LEVEL, dictionary); + } + + private static ZstdDictionaryCompressor getOrCreate(int level, ZstdCompressionDictionary dictionary) + { + if (dictionary == null) + { + return instancesPerLevel.computeIfAbsent(level, ZstdDictionaryCompressor::new); + } + + return instancePerDict.get(dictionary, dict -> { + // Get a reference to the dictionary when creating new compressor + Ref ref = dict != null ?
dict.tryRef() : null; + if (ref == null && dict != null) + { + // Dictionary is being closed, cannot create compressor + throw new IllegalStateException("Dictionary is being closed"); + } + return new ZstdDictionaryCompressor(level, dictionary, ref); + }); + } + + private ZstdDictionaryCompressor(int level) + { + this(level, null, null); + } + + private ZstdDictionaryCompressor(int level, ZstdCompressionDictionary dictionary, Ref dictionaryRef) + { + super(level, Set.of(COMPRESSION_LEVEL_OPTION_NAME)); + this.dictionary = dictionary; + this.dictionaryRef = dictionaryRef; + } + + @Override + public ZstdDictionaryCompressor getOrCopyWithDictionary(ZstdCompressionDictionary compressionDictionary) + { + return getOrCreate(compressionLevel(), compressionDictionary); + } + + @Override + public Kind acceptableDictionaryKind() + { + return Kind.ZSTD; + } + + @Override + public int uncompress(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) throws IOException + { + // fallback to non-dict zstd compressor + if (dictionary == null) + { + return super.uncompress(input, inputOffset, inputLength, output, outputOffset); + } + + int dsz; + try + { + dsz = (int) Zstd.decompressFastDict(output, outputOffset, + input, inputOffset, inputLength, + dictionary.dictionaryForDecompression()); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + + if (Zstd.isError(dsz)) + throw new IOException("Decompression failed due to " + Zstd.getErrorName(dsz)); + + return dsz; + } + + @Override + public void uncompress(ByteBuffer input, ByteBuffer output) throws IOException + { + if (dictionary == null) + { + super.uncompress(input, output); + return; + } + + try + { + // Zstd compressors expect only direct bytebuffer. See ZstdCompressorBase.preferredBufferType and supports + int decompressedSize = (int) Zstd.decompressDirectByteBufferFastDict(output, output.position(), output.limit() - output.position(), + input, input.position(), input.limit() - input.position(), + dictionary.dictionaryForDecompression()); + output.position(output.position() + decompressedSize); + input.position(input.limit()); + } + catch (Exception e) + { + throw new IOException("Decompression failed", e); + } + } + + @Override + public void compress(ByteBuffer input, ByteBuffer output) throws IOException + { + if (dictionary == null) + { + super.compress(input, output); + return; + } + + try + { + // Zstd compressors expect only direct bytebuffer. 
See ZstdCompressorBase.preferredBufferType and supports + int compressedSize = (int) Zstd.compressDirectByteBufferFastDict(output, output.position(), output.limit() - output.position(), + input, input.position(), input.limit() - input.position(), + dictionary.dictionaryForCompression(compressionLevel())); + output.position(output.position() + compressedSize); + input.position(input.limit()); + } + catch (Exception e) + { + throw new IOException("Compression failed", e); + } + } + + @VisibleForTesting + ZstdCompressionDictionary dictionary() + { + return dictionary; + } + + @VisibleForTesting + public static void invalidateCache() + { + instancePerDict.invalidateAll(); + } +} diff --git a/src/java/org/apache/cassandra/io/sstable/SSTable.java b/src/java/org/apache/cassandra/io/sstable/SSTable.java index 14c7af6cd5c8..0ee4de1897ec 100644 --- a/src/java/org/apache/cassandra/io/sstable/SSTable.java +++ b/src/java/org/apache/cassandra/io/sstable/SSTable.java @@ -43,6 +43,7 @@ import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.sstable.format.SSTableFormat; import org.apache.cassandra.io.sstable.format.SSTableFormat.Components; import org.apache.cassandra.io.sstable.format.TOCComponent; @@ -369,6 +370,8 @@ public interface Owner OpOrder.Barrier newReadOrderingBarrier(); TableMetrics getMetrics(); + + CompressionDictionaryManager compressionDictionaryManager(); } /** diff --git a/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java b/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java index 6c7b7b6d7fe2..fa541d075c1d 100644 --- a/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/SimpleSSTableMultiWriter.java @@ -23,6 +23,7 @@ import org.apache.cassandra.db.SerializationHeader; import org.apache.cassandra.db.commitlog.CommitLogPosition; import org.apache.cassandra.db.commitlog.IntervalSet; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.index.Index; @@ -122,17 +123,21 @@ public static SSTableMultiWriter create(Descriptor descriptor, MetadataCollector metadataCollector = new MetadataCollector(metadata.get().comparator) .commitLogIntervals(commitLogPositions != null ? commitLogPositions : IntervalSet.empty()) .sstableLevel(sstableLevel); - SSTableWriter writer = descriptor.getFormat().getWriterFactory().builder(descriptor) - .setKeyCount(keyCount) - .setRepairedAt(repairedAt) - .setPendingRepair(pendingRepair) - .setTransientSSTable(isTransient) - .setTableMetadataRef(metadata) - .setMetadataCollector(metadataCollector) - .setSerializationHeader(header) - .addDefaultComponents(indexGroups) - .setSecondaryIndexGroups(indexGroups) - .build(txn, owner); + CompressionDictionaryManager compressionDictionaryManager = owner == null ? 
null : owner.compressionDictionaryManager(); + SSTableWriter writer = descriptor.getFormat() + .getWriterFactory() + .builder(descriptor) + .setKeyCount(keyCount) + .setRepairedAt(repairedAt) + .setPendingRepair(pendingRepair) + .setTransientSSTable(isTransient) + .setTableMetadataRef(metadata) + .setMetadataCollector(metadataCollector) + .setSerializationHeader(header) + .addDefaultComponents(indexGroups) + .setSecondaryIndexGroups(indexGroups) + .setCompressionDictionaryManager(compressionDictionaryManager) + .build(txn, owner); return new SimpleSSTableMultiWriter(writer, txn); } } diff --git a/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java b/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java index 0e24fa991d72..abb0658c9d33 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java +++ b/src/java/org/apache/cassandra/io/sstable/format/CompressionInfoComponent.java @@ -22,6 +22,9 @@ import java.nio.file.NoSuchFileException; import java.util.Set; +import javax.annotation.Nullable; + +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.FSReadError; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.Component; @@ -32,27 +35,31 @@ public class CompressionInfoComponent { - public static CompressionMetadata maybeLoad(Descriptor descriptor, Set components) + public static CompressionMetadata maybeLoad(Descriptor descriptor, Set components, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { if (components.contains(Components.COMPRESSION_INFO)) - return load(descriptor); + return load(descriptor, compressionDictionaryManager); return null; } - public static CompressionMetadata loadIfExists(Descriptor descriptor) + public static CompressionMetadata loadIfExists(Descriptor descriptor, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { if (descriptor.fileFor(Components.COMPRESSION_INFO).exists()) - return load(descriptor); + return load(descriptor, compressionDictionaryManager); return null; } - public static CompressionMetadata load(Descriptor descriptor) + public static CompressionMetadata load(Descriptor descriptor, + @Nullable CompressionDictionaryManager compressionDictionaryManager) { return CompressionMetadata.open(descriptor.fileFor(Components.COMPRESSION_INFO), descriptor.fileFor(Components.DATA).length(), - descriptor.version.hasMaxCompressedLength()); + descriptor.version.hasMaxCompressedLength(), + compressionDictionaryManager); } /** diff --git a/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java b/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java index 9367cb444d80..69528628d348 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java +++ b/src/java/org/apache/cassandra/io/sstable/format/DataComponent.java @@ -20,6 +20,7 @@ import org.apache.cassandra.config.Config.FlushCompression; import org.apache.cassandra.db.compaction.OperationType; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.compress.CompressedSequentialWriter; import org.apache.cassandra.io.compress.ICompressor; import org.apache.cassandra.io.sstable.Descriptor; @@ -38,7 +39,8 @@ public static SequentialWriter buildWriter(Descriptor descriptor, SequentialWriterOption options, MetadataCollector metadataCollector, OperationType operationType, - FlushCompression flushCompression) + 
FlushCompression flushCompression, + CompressionDictionaryManager compressionDictionaryManager) { if (metadata.params.compression.isEnabled()) { @@ -49,7 +51,8 @@ public static SequentialWriter buildWriter(Descriptor descriptor, descriptor.fileFor(Components.DIGEST), options, compressionParams, - metadataCollector); + metadataCollector, + compressionDictionaryManager); } else { diff --git a/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java index 113332b10207..1af8b4ea8c1a 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/SSTableWriter.java @@ -28,6 +28,7 @@ import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; +import javax.annotation.Nullable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -37,6 +38,7 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.SerializationHeader; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.db.lifecycle.ILifecycleTransaction; import org.apache.cassandra.db.rows.UnfilteredRowIterator; import org.apache.cassandra.dht.AbstractBounds; @@ -437,6 +439,8 @@ public abstract static class Builder indexGroups; + @Nullable + private CompressionDictionaryManager compressionDictionaryManager; public B setMetadataCollector(MetadataCollector metadataCollector) { @@ -515,6 +519,18 @@ public B setSecondaryIndexGroups(Collection indexGroups) return (B) this; } + public B setCompressionDictionaryManager(CompressionDictionaryManager compressionDictionaryManager) + { + this.compressionDictionaryManager = compressionDictionaryManager; + return (B) this; + } + + @Nullable + public CompressionDictionaryManager getCompressionDictionaryManager() + { + return compressionDictionaryManager; + } + public MetadataCollector getMetadataCollector() { return metadataCollector; diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java index e987506c40f8..e6b60c2a0656 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigFormat.java @@ -433,7 +433,8 @@ public BigTableWriter.Builder builder(Descriptor descriptor) static class BigVersion extends Version { - public static final String current_version = DatabaseDescriptor.getStorageCompatibilityMode().isBefore(5) ? "nb" : "oa"; + public static final String current_version = DatabaseDescriptor.getStorageCompatibilityMode().isBefore(5) ? "nb" : + DatabaseDescriptor.getStorageCompatibilityMode().isBefore(6) ? 
"oa" : "pa"; public static final String earliest_supported_version = "ma"; // ma (3.0.0): swap bf hash order @@ -448,6 +449,7 @@ static class BigVersion extends Version // oa (5.0): improved min/max, partition level deletion presence marker, key range (CASSANDRA-18134) // Long deletionTime to prevent TTL overflow // token space coverage + // pa (6.0): compression dictionary metadata in CompressionInfo component // // NOTE: When adding a new version: // - Please add it to LegacySSTableTest diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java index 84e02217d565..3557b0d80227 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigSSTableReaderLoadingBuilder.java @@ -26,6 +26,7 @@ import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.SerializationHeader; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.Downsampling; import org.apache.cassandra.io.sstable.KeyReader; @@ -137,7 +138,8 @@ protected void openComponents(BigTableReader.Builder builder, SSTable.Owner owne } } - try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components)) + CompressionDictionaryManager compressionDictionaryManager = owner == null ? null : owner.compressionDictionaryManager(); + try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, compressionDictionaryManager)) { builder.setDataFile(dataFileBuilder(builder.getStatsMetadata()) .withCompressionMetadata(compressionMetadata) diff --git a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java index 3233ca4c0633..6cf01e7fbcb6 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/big/BigTableWriter.java @@ -389,7 +389,8 @@ protected SequentialWriter openDataWriter() getIOOptions().writerOptions, getMetadataCollector(), ensuringInBuildInternalContext(operationType), - getIOOptions().flushCompression); + getIOOptions().flushCompression, + getCompressionDictionaryManager()); this.dataWriterOpened = true; return dataWriter; } diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java index e7703c6a0612..8c6dbf1f1f62 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiFormat.java @@ -286,11 +286,12 @@ public long estimateSize(SSTableWriter.SSTableSizeParameters parameters) static class BtiVersion extends Version { - public static final String current_version = "da"; + public static final String current_version = "ea"; public static final String earliest_supported_version = "da"; // versions aa-cz are not supported in OSS - // da (5.0): initial version of the BIT format + // da (5.0): initial version of the BTI format + // ea (6.0): compression dictionary metadata in CompressionInfo component // NOTE: when adding a new version, please add that to LegacySSTableTest, too. 
private final boolean isLatestVersion; diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java index fa408adc5d0e..8f47e89a9a96 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableReaderLoadingBuilder.java @@ -23,7 +23,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.db.ColumnFamilyStore; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.compression.CompressionDictionaryManager; import org.apache.cassandra.dht.IPartitioner; import org.apache.cassandra.io.compress.CompressionMetadata; import org.apache.cassandra.io.sstable.KeyReader; @@ -38,6 +40,7 @@ import org.apache.cassandra.io.sstable.metadata.ValidationMetadata; import org.apache.cassandra.io.util.FileHandle; import org.apache.cassandra.metrics.TableMetrics; +import org.apache.cassandra.schema.Schema; import org.apache.cassandra.utils.FilterFactory; import org.apache.cassandra.utils.IFilter; import org.apache.cassandra.utils.Throwables; @@ -68,8 +71,15 @@ private KeyReader createKeyReader(StatsMetadata statsMetadata) throws IOExceptio { checkNotNull(statsMetadata); + ColumnFamilyStore cfs = Schema.instance.getColumnFamilyStoreInstance(tableMetadataRef.id); + CompressionDictionaryManager compressionDictionaryManager = null; + if (cfs != null) + { + compressionDictionaryManager = cfs.compressionDictionaryManager(); + } + try (PartitionIndex index = PartitionIndex.load(partitionIndexFileBuilder(), tableMetadataRef.getLocal().partitioner, false); - CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components); + CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, compressionDictionaryManager); FileHandle dFile = dataFileBuilder(statsMetadata).withCompressionMetadata(compressionMetadata) .withCrcCheckChance(() -> tableMetadataRef.getLocal().params.crcCheckChance) .complete(); @@ -131,7 +141,7 @@ protected void openComponents(BtiTableReader.Builder builder, SSTable.Owner owne } } - try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components)) + try (CompressionMetadata compressionMetadata = CompressionInfoComponent.maybeLoad(descriptor, components, owner == null ? 
null : owner.compressionDictionaryManager())) + { + builder.setDataFile(dataFileBuilder(builder.getStatsMetadata()) + .withCompressionMetadata(compressionMetadata) diff --git a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java index 074c5c17085c..0179497f0988 100644 --- a/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java +++ b/src/java/org/apache/cassandra/io/sstable/format/bti/BtiTableWriter.java @@ -334,7 +334,8 @@ protected SequentialWriter openDataWriter() getIOOptions().writerOptions, getMetadataCollector(), ensuringInBuildInternalContext(operationType), - getIOOptions().flushCompression); + getIOOptions().flushCompression, + getCompressionDictionaryManager()); } @Override diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java index a636005beef5..129d5114dd2b 100644 --- a/src/java/org/apache/cassandra/net/MessagingService.java +++ b/src/java/org/apache/cassandra/net/MessagingService.java @@ -478,21 +478,28 @@ public void respond(V response, Message message) public Future sendWithResponse(InetAddressAndPort to, Message msg) { Promise future = AsyncPromise.uncancellable(); - MessagingService.instance().sendWithCallback(msg, to, - new RequestCallback() - { - @Override - public void onResponse(Message msg) - { - future.setSuccess(msg.payload); - } - - @Override - public void onFailure(InetAddressAndPort from, RequestFailure failure) - { - future.setFailure(new RuntimeException(failure.toString())); - } - }); + RequestCallback callback = new RequestCallback() + { + @Override + public void onResponse(Message msg) + { + future.setSuccess(msg.payload); + } + + @Override + public void onFailure(InetAddressAndPort from, RequestFailure failure) + { + future.setFailure(new RuntimeException(failure.toString())); + } + }; + try + { + MessagingService.instance().sendWithCallback(msg, to, callback); + } + catch (Throwable e) // catch any exception while sending the message and wrap it inside the future to have unified exception handling + { + future.setFailure(e); + } return future; } diff --git a/src/java/org/apache/cassandra/net/Verb.java b/src/java/org/apache/cassandra/net/Verb.java index d24c9e64adff..14b78405a51d 100644 --- a/src/java/org/apache/cassandra/net/Verb.java +++ b/src/java/org/apache/cassandra/net/Verb.java @@ -44,6 +44,8 @@ import org.apache.cassandra.db.TruncateRequest; import org.apache.cassandra.db.TruncateResponse; import org.apache.cassandra.db.TruncateVerbHandler; +import org.apache.cassandra.db.compression.CompressionDictionaryUpdateMessage; +import org.apache.cassandra.db.compression.CompressionDictionaryUpdateVerbHandler; import org.apache.cassandra.exceptions.RequestFailure; import org.apache.cassandra.gms.GossipDigestAck; import org.apache.cassandra.gms.GossipDigestAck2; @@ -372,6 +374,9 @@ public enum Verb ACCORD_FETCH_TOPOLOGY_RSP (169, P0, shortTimeout, FETCH_METADATA, () -> accordEmbedded(FetchTopologies.responseSerializer), RESPONSE_HANDLER), ACCORD_FETCH_TOPOLOGY_REQ (170, P0, shortTimeout, FETCH_METADATA, () -> accordEmbedded(FetchTopologies.serializer), () -> FetchTopologies.handler, ACCORD_FETCH_TOPOLOGY_RSP), + DICTIONARY_UPDATE_RSP (171, P1, rpcTimeout, REQUEST_RESPONSE, () -> NoPayload.serializer, RESPONSE_HANDLER ), + DICTIONARY_UPDATE_REQ (172, P1, rpcTimeout, MISC, () -> CompressionDictionaryUpdateMessage.serializer, () -> CompressionDictionaryUpdateVerbHandler.instance,
DICTIONARY_UPDATE_RSP ), + // generic failure response FAILURE_RSP (99, P0, noTimeout, REQUEST_RESPONSE, () -> RequestFailure.serializer, RESPONSE_HANDLER ), @@ -679,4 +684,4 @@ class VerbTimeouts class ResponseHandlerSupplier { static final Supplier> RESPONSE_HANDLER = () -> ResponseVerbHandler.instance; -} \ No newline at end of file +} diff --git a/src/java/org/apache/cassandra/schema/CompressionParams.java b/src/java/org/apache/cassandra/schema/CompressionParams.java index 0e7c3da13ab0..fdf184e94d96 100644 --- a/src/java/org/apache/cassandra/schema/CompressionParams.java +++ b/src/java/org/apache/cassandra/schema/CompressionParams.java @@ -160,15 +160,31 @@ public static CompressionParams lz4(int chunkLength, int maxCompressedLength) return new CompressionParams(LZ4Compressor.create(Collections.emptyMap()), chunkLength, maxCompressedLength, calcMinCompressRatio(chunkLength, maxCompressedLength), Collections.emptyMap()); } + @VisibleForTesting public static CompressionParams zstd() { - return zstd(DEFAULT_CHUNK_LENGTH); + return zstd(DEFAULT_CHUNK_LENGTH, false); } + @VisibleForTesting public static CompressionParams zstd(Integer chunkLength) { - ZstdCompressor compressor = ZstdCompressor.create(Collections.emptyMap()); - return new CompressionParams(compressor, chunkLength, Integer.MAX_VALUE, DEFAULT_MIN_COMPRESS_RATIO, Collections.emptyMap()); + return zstd(chunkLength, false); + } + + @VisibleForTesting + public static CompressionParams zstd(Integer chunkLength, boolean useDictionary) + { + return zstd(chunkLength, useDictionary, Collections.emptyMap()); + } + + @VisibleForTesting + public static CompressionParams zstd(Integer chunkLength, boolean useDictionary, Map options) + { + ICompressor compressor = useDictionary + ? ZstdDictionaryCompressor.create(options) + : ZstdCompressor.create(options); + return new CompressionParams(compressor, chunkLength, Integer.MAX_VALUE, DEFAULT_MIN_COMPRESS_RATIO, options); } @VisibleForTesting @@ -223,6 +239,18 @@ public boolean isEnabled() return sstableCompressor != null; } + /** + * Checks if dictionary compression is enabled for this configuration. + * Dictionary compression is enabled when both compression is enabled and + * the compressor supports dictionary-based compression. + * + * @return {@code true} if dictionary compression is enabled, {@code false} otherwise. + */ + public boolean isDictionaryCompressionEnabled() + { + return isEnabled() && sstableCompressor instanceof IDictionaryCompressor; + } + /** * Returns the SSTable compressor. * @return the SSTable compressor or {@code null} if compression is disabled. 
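For context, the following is a minimal usage sketch (not part of the patch) of the dictionary-aware factories added to CompressionParams above. It relies only on the @VisibleForTesting zstd(chunkLength, useDictionary, options) overload and isDictionaryCompressionEnabled() shown in this hunk, plus the COMPRESSION_LEVEL_OPTION_NAME constant from ZstdCompressorBase; the wrapper class name, chunk length, and level values are illustrative assumptions.

import java.util.Map;

import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.io.compress.ZstdCompressorBase;
import org.apache.cassandra.schema.CompressionParams;

public class ZstdDictionaryParamsSketch
{
    public static void main(String[] args)
    {
        // The dictionary compressor cache sizes itself from DatabaseDescriptor, so the
        // configuration must be initialized first (the microbenchmarks below do the same).
        DatabaseDescriptor.daemonInitialization();

        // Compression options as they would arrive from the table schema; only the level is set here.
        Map<String, String> options = Map.of(ZstdCompressorBase.COMPRESSION_LEVEL_OPTION_NAME, "3");

        // useDictionary = true selects ZstdDictionaryCompressor, false selects the plain ZstdCompressor.
        CompressionParams withDict = CompressionParams.zstd(16 * 1024, true, options);
        CompressionParams plain = CompressionParams.zstd(16 * 1024, false, options);

        // isDictionaryCompressionEnabled() is true only when the configured compressor
        // implements IDictionaryCompressor.
        System.out.println(withDict.isDictionaryCompressionEnabled()); // true
        System.out.println(plain.isDictionaryCompressionEnabled());    // false
    }
}

Until a trained dictionary is attached via getOrCopyWithDictionary(...), a compressor created this way compresses the same as the non-dictionary Zstd path, matching the fallback behaviour noted in ZstdDictionaryCompressor above.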
diff --git a/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java b/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java index d50621a3a15c..3488858c79c8 100644 --- a/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java +++ b/src/java/org/apache/cassandra/schema/SystemDistributedKeyspace.java @@ -30,6 +30,8 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; + import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableMap; @@ -49,6 +51,7 @@ import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.dht.Range; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.db.compression.CompressionDictionary; import org.apache.cassandra.locator.InetAddressAndPort; import org.apache.cassandra.repair.CommonRange; import org.apache.cassandra.repair.messages.RepairOption; @@ -56,7 +59,6 @@ import org.apache.cassandra.utils.TimeUUID; import static java.lang.String.format; - import static org.apache.cassandra.utils.ByteBufferUtil.bytes; public final class SystemDistributedKeyspace @@ -83,10 +85,11 @@ private SystemDistributedKeyspace() * gen 5: add ttl and TWCS to repair_history tables * gen 6: add denylist table * gen 7: add auto_repair_history and auto_repair_priority tables for AutoRepair feature + * gen 8: add compression_dictionaries for dictionary-based compression algorithms (e.g. zstd) * * // TODO: TCM - how do we evolve these tables? */ - public static final long GENERATION = 7; + public static final long GENERATION = 8; public static final String REPAIR_HISTORY = "repair_history"; @@ -100,7 +103,12 @@ private SystemDistributedKeyspace() public static final String AUTO_REPAIR_PRIORITY = "auto_repair_priority"; - public static final Set TABLE_NAMES = ImmutableSet.of(REPAIR_HISTORY, PARENT_REPAIR_HISTORY, VIEW_BUILD_STATUS, PARTITION_DENYLIST_TABLE, AUTO_REPAIR_HISTORY, AUTO_REPAIR_PRIORITY); + public static final String COMPRESSION_DICTIONARIES = "compression_dictionaries"; + + public static final Set TABLE_NAMES = ImmutableSet.of(REPAIR_HISTORY, PARENT_REPAIR_HISTORY, + VIEW_BUILD_STATUS, PARTITION_DENYLIST_TABLE, + AUTO_REPAIR_HISTORY, AUTO_REPAIR_PRIORITY, + COMPRESSION_DICTIONARIES); public static final String REPAIR_HISTORY_CQL = "CREATE TABLE IF NOT EXISTS %s (" + "keyspace_name text," @@ -185,6 +193,18 @@ private SystemDistributedKeyspace() private static final TableMetadata AutoRepairPriorityTable = parse(AUTO_REPAIR_PRIORITY, "Auto repair priority for each group", AUTO_REPAIR_PRIORITY_CQL).build(); + public static final String COMPRESSION_DICTIONARIES_CQL = "CREATE TABLE IF NOT EXISTS %s (" + + "keyspace_name text," + + "table_name text," + + "kind text," + + "dict_id bigint," + + "dict blob," + + "PRIMARY KEY ((keyspace_name, table_name), dict_id)) " + + "WITH CLUSTERING ORDER BY (dict_id DESC)"; // in order to retrieve the latest dictionary; the contract is the newer the dictionary the larger the dict_id + + private static final TableMetadata CompressionDictionariesTable = + parse(COMPRESSION_DICTIONARIES, "Compression dictionaries for applicable tables", COMPRESSION_DICTIONARIES_CQL).build(); + private static TableMetadata.Builder parse(String table, String description, String cql) { return CreateTableStatement.parse(format(cql, table), SchemaConstants.DISTRIBUTED_KEYSPACE_NAME) @@ -197,7 +217,10 @@ public static KeyspaceMetadata metadata() { return 
KeyspaceMetadata.create(SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, KeyspaceParams.simple(Math.max(DEFAULT_RF, DatabaseDescriptor.getDefaultKeyspaceRF())), - Tables.of(RepairHistory, ParentRepairHistory, ViewBuildStatus, PartitionDenylistTable, AutoRepairHistoryTable, AutoRepairPriorityTable)); + Tables.of(RepairHistory, ParentRepairHistory, + ViewBuildStatus, PartitionDenylistTable, + AutoRepairHistoryTable, AutoRepairPriorityTable, + CompressionDictionariesTable)); } public static void startParentRepair(TimeUUID parent_id, String keyspaceName, String[] cfnames, RepairOption options) @@ -382,20 +405,97 @@ public static void setViewRemoved(String keyspaceName, String viewName) forceBlockingFlush(VIEW_BUILD_STATUS, ColumnFamilyStore.FlushReason.INTERNALLY_FORCED); } - private static void processSilent(String fmtQry, String... values) + /** + * Stores a compression dictionary for a given keyspace and table in the distributed system keyspace. + * + * @param keyspaceName the keyspace name to associate with the dictionary + * @param tableName the table name to associate with the dictionary + * @param dictionary the compression dictionary to store + */ + public static void storeCompressionDictionary(String keyspaceName, String tableName, CompressionDictionary dictionary) + { + String query = "INSERT INTO %s.%s (keyspace_name, table_name, kind, dict_id, dict) VALUES ('%s', '%s', '%s', %s, ?)"; + String fmtQuery = format(query, + SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, + COMPRESSION_DICTIONARIES, + keyspaceName, + tableName, + dictionary.kind(), + dictionary.identifier().id); + noThrow(fmtQuery, + () -> QueryProcessor.process(fmtQuery, ConsistencyLevel.ONE, + Collections.singletonList(ByteBuffer.wrap(dictionary.rawDictionary())))); + } + + /** + * Retrieves the latest compression dictionary for a given keyspace and table. + * + * @param keyspaceName the keyspace name to retrieve the dictionary for + * @param tableName the table name to retrieve the dictionary for + * @return the latest compression dictionary for the specified keyspace and table, + * or null if no dictionary exists or if an error occurs during retrieval + */ + @Nullable + public static CompressionDictionary retrieveLatestCompressionDictionary(String keyspaceName, String tableName) + { + String query = "SELECT kind, dict_id, dict FROM %s.%s WHERE keyspace_name='%s' AND table_name='%s' LIMIT 1"; + String fmtQuery = format(query, SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, COMPRESSION_DICTIONARIES, keyspaceName, tableName); + try + { + UntypedResultSet.Row row = QueryProcessor.execute(fmtQuery, ConsistencyLevel.ONE).one(); + return CompressionDictionary.createFromRow(row); + } + catch (Exception e) + { + return null; + } + } + + /** + * Retrieves a specific compression dictionary for a given keyspace and table. 
+ * + * @param keyspaceName the keyspace name to retrieve the dictionary for + * @param tableName the table name to retrieve the dictionary for + * @param dictionaryId the dictionary id to retrieve the dictionary for + * @return the compression dictionary identified by the specified keyspace, table and dictionaryId, + * or null if no dictionary exists or if an error occurs during retrieval + */ + public static CompressionDictionary retrieveCompressionDictionary(String keyspaceName, String tableName, CompressionDictionary.DictId dictionaryId) { + String query = "SELECT kind, dict_id, dict FROM %s.%s WHERE keyspace_name='%s' AND table_name='%s' AND dict_id=%s"; + String fmtQuery = format(query, SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, COMPRESSION_DICTIONARIES, keyspaceName, tableName, dictionaryId.id); try { + UntypedResultSet.Row row = QueryProcessor.execute(fmtQuery, ConsistencyLevel.ONE).one(); + return CompressionDictionary.createFromRow(row); + } + catch (Exception e) + { + return null; + } + } + + private static void processSilent(String fmtQry, String... values) + { + noThrow(fmtQry, () -> { List valueList = new ArrayList<>(values.length); for (String v : values) { valueList.add(bytes(v)); } QueryProcessor.process(fmtQry, ConsistencyLevel.ANY, valueList); + }); + } + + private static void noThrow(String fmtQry, Runnable queryExec) + { + try + { + queryExec.run(); } catch (Throwable t) { - logger.error("Error executing query "+fmtQry, t); + logger.error("Error executing query " + fmtQry, t); } } diff --git a/src/java/org/apache/cassandra/tools/NodeProbe.java b/src/java/org/apache/cassandra/tools/NodeProbe.java index b59655134aaa..7f270ae74dca 100644 --- a/src/java/org/apache/cassandra/tools/NodeProbe.java +++ b/src/java/org/apache/cassandra/tools/NodeProbe.java @@ -88,10 +88,11 @@ import org.apache.cassandra.db.ColumnFamilyStoreMBean; import org.apache.cassandra.db.compaction.CompactionManager; import org.apache.cassandra.db.compaction.CompactionManagerMBean; -import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTable; -import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTableMBean; +import org.apache.cassandra.db.compression.CompressionDictionaryManagerMBean; import org.apache.cassandra.db.guardrails.Guardrails; import org.apache.cassandra.db.guardrails.GuardrailsMBean; +import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTable; +import org.apache.cassandra.db.virtual.CIDRFilteringMetricsTableMBean; import org.apache.cassandra.fql.FullQueryLoggerOptions; import org.apache.cassandra.fql.FullQueryLoggerOptionsCompositeData; import org.apache.cassandra.gms.FailureDetector; @@ -2682,6 +2683,61 @@ public void setMixedMajorVersionRepairEnabled(boolean enabled) { autoRepairProxy.setMixedMajorVersionRepairEnabled(enabled); } + + /** + * Triggers compression dictionary training for the specified table. + * + * @param keyspace the keyspace name + * @param table the table name + * @param options options for the training process (currently unused, reserved for future extensions) + * @throws IOException if there's an error accessing the MBean + * @throws IllegalArgumentException if table doesn't support dictionary compression + */ + public void trainCompressionDictionary(String keyspace, String table, Map options) throws IOException + { + getDictionaryManagerProxy(keyspace, table).train(options); + } + + /** + * Gets the compression dictionary training status for the specified table. 
+ * + * @param keyspace the keyspace name + * @param table the table name + * @return the training status as string + * @throws IOException if there's an error accessing the MBean + */ + public String getCompressionDictionaryTrainingStatus(String keyspace, String table) throws IOException + { + return getDictionaryManagerProxy(keyspace, table).getTrainingStatus(); + } + + /** + * Updates the sampling rate for compression dictionary training. + * + * @param keyspace the keyspace name + * @param table the table name + * @param samplingRate the new sampling rate (1 = sample every time, 2 = sample every 2nd time, etc.) + * @throws IOException if there's an error accessing the MBean + */ + public void updateCompressionDictionaryTrainingSamplingRate(String keyspace, String table, int samplingRate) throws IOException + { + getDictionaryManagerProxy(keyspace, table).updateSamplingRate(samplingRate); + } + + private CompressionDictionaryManagerMBean getDictionaryManagerProxy(String keyspace, String table) throws IOException + { + // Construct table-specific MBean name + String mbeanName = "org.apache.cassandra.db.compression:type=CompressionDictionaryManager,keyspace=" + keyspace + ",table=" + table; + try + { + ObjectName objectName = new ObjectName(mbeanName); + return JMX.newMBeanProxy(mbeanServerConn, objectName, CompressionDictionaryManagerMBean.class); + } + catch (MalformedObjectNameException e) + { + throw new IOException("Invalid keyspace or table name", e); + } + } } class ColumnFamilyStoreMBeanIterator implements Iterator> diff --git a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java index 256c80d26903..11df06158d7d 100644 --- a/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java +++ b/src/java/org/apache/cassandra/tools/SSTableMetadataViewer.java @@ -322,7 +322,7 @@ private void printSStableMetadata(File file, boolean scan) throws IOException CompactionMetadata compaction = statsComponent.compactionMetadata(); SerializationHeader.Component header = statsComponent.serializationHeader(); Class compressorClass = null; - try (CompressionMetadata compression = CompressionInfoComponent.loadIfExists(descriptor)) + try (CompressionMetadata compression = CompressionInfoComponent.loadIfExists(descriptor, null)) { compressorClass = compression != null ? compression.compressor().getClass() : null; } diff --git a/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java b/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java index c20c6936fc4d..3415c52311a7 100644 --- a/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java +++ b/src/java/org/apache/cassandra/tools/nodetool/NodetoolCommand.java @@ -207,6 +207,7 @@ TableStats.class, TopPartitions.class, TpStats.class, + TrainCompressionDictionary.class, TruncateHints.class, UpdateCIDRGroup.class, UpgradeSSTable.class, diff --git a/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java b/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java new file mode 100644 index 000000000000..4196bd38c805 --- /dev/null +++ b/src/java/org/apache/cassandra/tools/nodetool/TrainCompressionDictionary.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.tools.nodetool; + +import java.io.PrintStream; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.Uninterruptibles; + +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.tools.NodeProbe; +import picocli.CommandLine.Command; +import picocli.CommandLine.Option; +import picocli.CommandLine.Parameters; + +@Command(name = "traincompressiondictionary", +description = "Manually trigger compression dictionary training for a table") +public class TrainCompressionDictionary extends AbstractCommand +{ + @Parameters(index = "0", description = "The keyspace name", arity = "1") + private String keyspace; + + @Parameters(index = "1", description = "The table name", arity = "1") + private String table; + + @Option(names = {"-d", "--max-sampling-duration"}, + description = "Maximum time to collect samples before training dictionary (default: 600 seconds)") + private int maxSamplingDurationSeconds = 600; + + @Option(names = {"-r", "--sampling-rate"}, + description = "Sampling rate as a double value in range (0, 1]. 1.0 means sample all data, 0.5 means sample 50%% of data") + private Double samplingRate; + + @Option(names = {"-a", "--async"}, + description = "Run training asynchronously without waiting for completion") + private boolean async = false; + + @Option(names = {"-s", "--status"}, + description = "Show current training status instead of starting new training") + private boolean showStatus = false; + + @Override + public void execute(NodeProbe probe) + { + if (showStatus) + { + showTrainingStatus(probe); + return; + } + + PrintStream out = probe.output().out; + PrintStream err = probe.output().err; + if (maxSamplingDurationSeconds <= 0) + { + err.printf("Invalid value for max-sampling-duration: %s%n", maxSamplingDurationSeconds); + System.exit(1); + } + + if (samplingRate != null && (samplingRate <= 0.0 || samplingRate > 1.0)) + { + err.printf("Invalid value for sampling-rate: %s. 
Must be in range (0, 1]%n", samplingRate); + System.exit(1); + } + try + { + out.printf("Starting compression dictionary training for %s.%s...%n", keyspace, table); + out.printf("Will collect samples for up to %d seconds before training%n", maxSamplingDurationSeconds); + if (samplingRate != null) + { + out.printf("Using sampling rate: %.2f (%.1f%%)%n", samplingRate, samplingRate * 100); + } + + // Build options map + Map options = new HashMap<>(); + options.put("maxSamplingDurationSeconds", String.valueOf(maxSamplingDurationSeconds)); + + probe.trainCompressionDictionary(keyspace, table, options); + + // Update sampling rate if provided (after training has started) + if (samplingRate != null) + { + // Convert from double (0, 1] to integer format (1/rate) + // Examples: 1.0 -> 1 (sample every time), 0.5 -> 2 (roughly sample every 2nd), 0.1 -> 10 (roughly sample every 10th) + int integerSamplingRate = (int) Math.round(1.0 / samplingRate); + probe.updateCompressionDictionaryTrainingSamplingRate(keyspace, table, integerSamplingRate); + } + + if (async) + { + out.printf("Training started asynchronously for %s.%s%n", keyspace, table); + out.printf("Use 'nodetool traincompressiondictionary --status %s %s' to check progress.%n", + keyspace, table); + return; + } + + // Wait for completion (training will start automatically after sampling period) + out.println("Collecting samples and training. (Since the trainer samples chunk data on " + + "writing to new SSTable, you might consider running nodetool 'flush' along " + + "with this command to have chunk available for sampling)"); + long maxWaitMillis = TimeUnit.SECONDS.toMillis(maxSamplingDurationSeconds + 300); // Add 5 minutes for training + long startTime = System.currentTimeMillis(); + + while (System.currentTimeMillis() - startTime < maxWaitMillis) + { + String statusStr = probe.getCompressionDictionaryTrainingStatus(keyspace, table); + TrainingStatus status = TrainingStatus.valueOf(statusStr); + if (TrainingStatus.COMPLETED == status) + { + out.printf("%nTraining completed successfully for %s.%s%n", keyspace, table); + return; + } + else if (TrainingStatus.FAILED == status) + { + err.printf("%nTraining failed for %s.%s%n", keyspace, table); + System.exit(1); + } + + out.print('.'); + + Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS); + } + + err.printf("%nTraining did not complete within expected timeframe (%d seconds sampling + 5 minutes training). 
Use --status to check current state.%n", + maxSamplingDurationSeconds); + System.exit(1); + } + catch (Exception e) + { + err.printf("Failed to trigger training: %s%n", e.getMessage()); + System.exit(1); + } + } + + private void showTrainingStatus(NodeProbe probe) + { + PrintStream out = probe.output().out; + PrintStream err = probe.output().err; + String statusStr = null; + try + { + statusStr = probe.getCompressionDictionaryTrainingStatus(keyspace, table); + } + catch (Exception e) + { + err.printf("Failed to get training status: %s%n", e.getMessage()); + System.exit(1); + } + + TrainingStatus status = TrainingStatus.valueOf(statusStr); + switch (status) + { + case NOT_STARTED: + out.printf("Trainer is not running for %s.%s%n", keyspace, table); + break; + case SAMPLING: + out.printf("Trainer is collecting sample data for %s.%s%n", keyspace, table); + break; + case TRAINING: + out.printf("Training is in progress for %s.%s%n", keyspace, table); + break; + case COMPLETED: + out.printf("Training is completed for %s.%s%n", keyspace, table); + break; + case FAILED: + err.printf("Training failed for %s.%s%n", keyspace, table); + break; + default: + err.printf("Encountered unexpected training status for %s.%s: %s%n", keyspace, table, status); + } + } +} diff --git a/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java b/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java index 2969597c2381..cfd61bd0b92c 100644 --- a/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java +++ b/src/java/org/apache/cassandra/utils/StorageCompatibilityMode.java @@ -35,6 +35,13 @@ public enum StorageCompatibilityMode */ CASSANDRA_4(4), + /** + * Similar to CASSANDRA_4. + * The new features in 6.0 are + * - ZSTD dictioanry-based compression. Once SSTables are compressed with dictioanry, they cannot be rolled back. + */ + CASSANDRA_5(5), + /** * Use the storage formats of the current version, but disabling features that are not compatible with any * not-upgraded nodes in the cluster. Use this during rolling upgrades to a new major Cassandra version. Once all diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java new file mode 100644 index 000000000000..12036d6d77dc --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorBenchBase.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Random; +import java.util.UUID; + +import com.github.luben.zstd.ZstdDictTrainer; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; + +// The bench takes over 20 minutes to finish +@State(Scope.Benchmark) +public abstract class ZstdDictionaryCompressorBenchBase +{ + @Param({"4096", "16384", "65536"}) + protected int dataSize; + + @Param({"CASSANDRA_LIKE", "COMPRESSIBLE", "MIXED"}) + protected DataType dataType; + + @Param({"0", "65536"}) + protected int dictionarySize; + + @Param({"3", "5", "7"}) + protected int compressionLevel; + + protected byte[] inputData; + protected ByteBuffer inputBuffer; + protected ByteBuffer compressedBuffer; + protected ByteBuffer decompressedBuffer; + protected ZstdDictionaryCompressor compressor; + protected ZstdDictionaryCompressor noDictCompressor; + protected ZstdCompressionDictionary dictionary; + + public enum DataType + { + CASSANDRA_LIKE, COMPRESSIBLE, MIXED + } + + @Setup(Level.Trial) + public void setupTrial() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Setup(Level.Iteration) + public void setupIteration() throws IOException + { + Random random = new Random(42); + + // Generate test data based on type + inputData = generateTestData(dataType, dataSize, random); + + // Create direct ByteBuffers (required by ZSTD) + inputBuffer = ByteBuffer.allocateDirect(dataSize); + inputBuffer.put(inputData); + inputBuffer.flip(); + + // Allocate buffers with extra space for compression overhead + int maxCompressedSize = dataSize + 1024; + compressedBuffer = ByteBuffer.allocateDirect(maxCompressedSize); + decompressedBuffer = ByteBuffer.allocateDirect(dataSize); + + // Create dictionary if needed + if (dictionarySize != 0) + { + dictionary = createDictionary(dataType, dictionarySize, random); + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + compressor = ZstdDictionaryCompressor.create(options).getOrCopyWithDictionary(dictionary); + } + else + { + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + compressor = ZstdDictionaryCompressor.create(options); + } + + // Always create a no-dictionary compressor for comparison + Map options = Map.of("compression_level", String.valueOf(compressionLevel)); + noDictCompressor = ZstdDictionaryCompressor.create(options); + } + + @TearDown(Level.Iteration) + public void tearDown() + { + if (dictionary != null) + { + dictionary.close(); + dictionary = null; + } + ZstdDictionaryCompressor.invalidateCache(); + } + + protected byte[] generateTestData(DataType type, int size, Random random) + { + byte[] data = new byte[size]; + + switch (type) + { + case CASSANDRA_LIKE: + generateCassandraLikeData(data, random); + break; + + case COMPRESSIBLE: + generateCompressibleData(data, random); + break; + + case MIXED: + generateMixedData(data, random); + break; + } + + return data; 
+ } + + private void generateCassandraLikeData(byte[] data, Random random) + { + StringBuilder sb = new StringBuilder(); + String[] patterns = { + "user_id_", "timestamp_", "session_", "event_type_", + "metadata_", "value_", "status_", "location_" + }; + + while (sb.length() < data.length) + { + String pattern = patterns[random.nextInt(patterns.length)]; + sb.append(pattern).append(UUID.randomUUID().toString()).append("|"); + sb.append("timestamp:").append(System.currentTimeMillis() + random.nextInt(86400000)).append("|"); + sb.append("value:").append(random.nextDouble()).append("|"); + sb.append("count:").append(random.nextInt(1000)).append("\n"); + } + + byte[] generated = sb.substring(0, Math.min(data.length, sb.length())).getBytes(); + System.arraycopy(generated, 0, data, 0, generated.length); + + // Fill remaining space with random data if needed + if (generated.length < data.length) + { + byte[] remaining = new byte[data.length - generated.length]; + random.nextBytes(remaining); + System.arraycopy(remaining, 0, data, generated.length, remaining.length); + } + } + + private void generateCompressibleData(byte[] data, Random random) + { + String pattern = "The quick brown fox jumps over the lazy dog. This is a highly compressible pattern that repeats. "; + byte[] patternBytes = pattern.getBytes(); + + for (int i = 0; i < data.length; i++) + { + data[i] = patternBytes[i % patternBytes.length]; + } + + // Add some randomness (10%) + for (int i = 0; i < data.length / 10; i++) + { + data[random.nextInt(data.length)] = (byte) random.nextInt(256); + } + } + + private void generateMixedData(byte[] data, Random random) + { + int quarter = data.length / 4; + + // 25% random + random.nextBytes(data); + + // 25% compressible + byte[] compressible = new byte[quarter]; + generateCompressibleData(compressible, random); + System.arraycopy(compressible, 0, data, quarter, quarter); + + // 50% Cassandra-like + byte[] cassandraLike = new byte[data.length - 2 * quarter]; + generateCassandraLikeData(cassandraLike, random); + System.arraycopy(cassandraLike, 0, data, 2 * quarter, cassandraLike.length); + } + + private ZstdCompressionDictionary createDictionary(DataType dataType, int dictSize, Random random) + { + // Generate training samples + byte[][] samples = new byte[100][]; + int totalSampleSize = 0; + for (int i = 0; i < samples.length; i++) + { + samples[i] = generateTestData(dataType, Math.min(1024, dataSize), random); + totalSampleSize += samples[i].length; + } + + // Train dictionary + ZstdDictTrainer trainer = new ZstdDictTrainer(totalSampleSize, dictSize); + for (byte[] sample : samples) + { + trainer.addSample(sample); + } + + byte[] dictData = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, 0); + return new ZstdCompressionDictionary(dictId, dictData); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java new file mode 100644 index 000000000000..95e208369292 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorRatioBench.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.cassandra.config.DatabaseDescriptor; + +// This is not really a bench, but shares common utilities with the base class. +// Running the class outputs the compression ratio and dictionary effectiveness of different configurations +public class ZstdDictionaryCompressorRatioBench extends ZstdDictionaryCompressorBenchBase +{ + private static class CompressionResult + { + final String configuration; + final double compressionRatio; + final double dictionaryEffectiveness; + + CompressionResult(String configuration, double compressionRatio, double dictionaryEffectiveness) + { + this.configuration = configuration; + this.compressionRatio = compressionRatio; + this.dictionaryEffectiveness = dictionaryEffectiveness; + } + } + + private CompressionResult measureCompressionRatio() throws IOException + { + // Compress with current compressor + inputBuffer.rewind(); + compressedBuffer.clear(); + compressor.compress(inputBuffer, compressedBuffer); + int compressedSize = compressedBuffer.position(); + + // Compress without dictionary for comparison + inputBuffer.rewind(); + compressedBuffer.clear(); + noDictCompressor.compress(inputBuffer, compressedBuffer); + int noDictCompressedSize = compressedBuffer.position(); + + // Calculate ratios + double compressionRatio = (double) inputBuffer.limit() / compressedSize; + double dictionaryEffectiveness = (double) noDictCompressedSize / compressedSize; + + // Create configuration string + String config = String.format("%s_%s_L%d_Chunk%dKiB", + dataType, + dictionarySize == 0 ?
"NoDict" : "WithDict", + compressionLevel, + dataSize / 1024); + + return new CompressionResult(config, compressionRatio, dictionaryEffectiveness); + } + + public static void main(String[] args) throws Exception + { + DatabaseDescriptor.daemonInitialization(); + + List allResults = new ArrayList<>(); + + // Define test parameters + DataType[] dataTypes = {DataType.CASSANDRA_LIKE, DataType.COMPRESSIBLE, DataType.MIXED}; + int[] dictionarySizes = {0, 65536}; + int[] compressionLevels = {3, 5, 7}; + int[] dataSizes = {4096, 16384, 65536}; + + System.out.println("Running ZSTD Dictionary Compressor Ratio Measurements..."); + System.out.println("Total configurations: " + (dataTypes.length * dictionarySizes.length * compressionLevels.length * dataSizes.length)); + + int configCount = 0; + for (DataType dataType : dataTypes) + { + for (int dictionarySize : dictionarySizes) + { + for (int compressionLevel : compressionLevels) + { + for (int dataSize : dataSizes) + { + configCount++; + ZstdDictionaryCompressorRatioBench bench = new ZstdDictionaryCompressorRatioBench(); + bench.dataType = dataType; + bench.dictionarySize = dictionarySize; + bench.compressionLevel = compressionLevel; + bench.dataSize = dataSize; + + try + { + bench.setupIteration(); + CompressionResult result = bench.measureCompressionRatio(); + allResults.add(result); + bench.tearDown(); + } + catch (Exception e) + { + System.err.println("Failed to process configuration: " + e.getMessage()); + e.printStackTrace(); + } + } + } + } + } + + // Print consolidated results + printConsolidatedResults(allResults); + } + + private static void printConsolidatedResults(List results) + { + StringBuilder report = new StringBuilder(); + report.append("\n").append("=".repeat(100)).append("\n"); + report.append("ZSTD DICTIONARY COMPRESSOR RATIO RESULTS").append("\n"); + report.append("=".repeat(100)).append("\n"); + report.append(String.format("%-50s %-20s %-20s%n", "Configuration", "Compression Ratio", "Dictionary Effectiveness")); + report.append("-".repeat(100)).append("\n"); + + for (CompressionResult entry : results) + { + report.append(String.format("%-50s %-20.3f %-20.3f%n", + entry.configuration, entry.compressionRatio, entry.dictionaryEffectiveness)); + } + + report.append("=".repeat(100)).append("\n"); + report.append("Compression Ratio: Original Size / Compressed Size (higher is better)").append("\n"); + report.append("Dictionary Effectiveness: Non-Dict Size / Dict Size (higher is better)").append("\n"); + report.append("=".repeat(100)).append("\n"); + + System.out.print(report.toString()); + } +} diff --git a/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java new file mode 100644 index 000000000000..601ad53ebd56 --- /dev/null +++ b/test/microbench/org/apache/cassandra/test/microbench/ZstdDictionaryCompressorThroughputBench.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.test.microbench; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode({Mode.Throughput, Mode.AverageTime}) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 1, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 2, time = 2, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1, jvmArgsAppend = "-Xmx1G") +public class ZstdDictionaryCompressorThroughputBench extends ZstdDictionaryCompressorBenchBase +{ + @Benchmark + public void compressionThroughput(Blackhole bh) throws IOException + { + inputBuffer.rewind(); + compressedBuffer.clear(); + + compressor.compress(inputBuffer, compressedBuffer); + bh.consume(compressedBuffer.position()); + } + + @Benchmark + public void decompressionThroughput(Blackhole bh) throws IOException + { + // First compress the data + inputBuffer.rewind(); + compressedBuffer.clear(); + compressor.compress(inputBuffer, compressedBuffer); + + // Then decompress it + compressedBuffer.flip(); + decompressedBuffer.clear(); + compressor.uncompress(compressedBuffer, decompressedBuffer); + + bh.consume(decompressedBuffer.position()); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java new file mode 100644 index 000000000000..54f3f30818fe --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryCacheTest.java @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import com.github.luben.zstd.ZstdDictTrainer; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; + +import static org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryCacheTest +{ + private static final String TEST_PATTERN = "The quick brown fox jumps over the lazy dog. "; + + private CompressionDictionaryCache cache; + private ZstdCompressionDictionary testDict1; + private ZstdCompressionDictionary testDict2; + private ZstdCompressionDictionary testDict3; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + cache = new CompressionDictionaryCache(); + testDict1 = createTestDictionary(1); + testDict2 = createTestDictionary(2); + testDict3 = createTestDictionary(3); + } + + @After + public void tearDown() + { + if (cache != null) + { + cache.close(); + } + + // Close dictionaries if not already closed + closeQuietly(testDict1); + closeQuietly(testDict2); + closeQuietly(testDict3); + } + + // Basic cache operations tests + + @Test + public void testGetCurrentInitiallyNull() + { + assertThat(cache.getCurrent()) + .as("Current dictionary should be null initially") + .isNull(); + } + + @Test + public void testAddAndGet() + { + cache.add(testDict1); + + CompressionDictionary retrieved = cache.get(testDict1.identifier()); + assertThat(retrieved) + .as("Should retrieve the same dictionary instance") + .isSameAs(testDict1); + } + + @Test + public void testGetNonExistentDictionary() + { + DictId nonExistentId = new DictId(Kind.ZSTD, 999); + assertThat(cache.get(nonExistentId)) + .as("Should return null for non-existent dictionary") + .isNull(); + } + + @Test + public void testAddMultipleDictionaries() + { + cache.add(testDict1); + cache.add(testDict2); + cache.add(testDict3); + + assertThat(cache.get(testDict1.identifier())).isSameAs(testDict1); + assertThat(cache.get(testDict2.identifier())).isSameAs(testDict2); + assertThat(cache.get(testDict3.identifier())).isSameAs(testDict3); + } + + @Test + public void testSetCurrentIfNewer() + { + cache.setCurrentIfNewer(testDict1); + assertThat(cache.getCurrent()) + .as("Should set first dictionary as current") + .isSameAs(testDict1); + + // Verify it was also added to cache + assertThat(cache.get(testDict1.identifier())).isSameAs(testDict1); + } + + @Test + public void testSetCurrentWithNewerDictionary() + { + cache.setCurrentIfNewer(testDict1); + cache.setCurrentIfNewer(testDict2); + + assertThat(cache.getCurrent()) + .as("Should update to newer dictionary") + .isSameAs(testDict2); + + // Both should be in cache + assertThat(cache.get(testDict1.identifier())).isSameAs(testDict1); + assertThat(cache.get(testDict2.identifier())).isSameAs(testDict2); + } + + @Test + public void testSetCurrentWithOlderDictionary() + { + cache.setCurrentIfNewer(testDict2); + cache.setCurrentIfNewer(testDict1); // older dictionary + + assertThat(cache.getCurrent()) + 
.as("Should keep newer dictionary as current") + .isSameAs(testDict2); + + // Both should be in cache + assertThat(cache.get(testDict1.identifier())).isSameAs(testDict1); + assertThat(cache.get(testDict2.identifier())).isSameAs(testDict2); + } + + @Test + public void testSetCurrentWithSameIdDictionary() + { + ZstdCompressionDictionary sameDictCopy = createTestDictionary(2); + + cache.setCurrentIfNewer(testDict2); + cache.setCurrentIfNewer(sameDictCopy); + + // Should not update since ID is the same (not newer) + assertThat(cache.getCurrent()) + .as("Should keep original dictionary as current") + .isSameAs(testDict2); + + sameDictCopy.close(); + } + + @Test + public void testSetCurrentWithNull() + { + cache.setCurrentIfNewer(testDict1); + cache.setCurrentIfNewer(null); + + // Should not change current dictionary + assertThat(cache.getCurrent()) + .as("Should keep existing dictionary as current") + .isSameAs(testDict1); + } + + @Test + public void testCacheClose() + { + cache.add(testDict1); + cache.add(testDict2); + cache.setCurrentIfNewer(testDict2); + + assertThat(cache.getCurrent()) + .as("Current should not be null before close") + .isNotNull(); + assertThat(cache.get(testDict1.identifier())) + .as("Cache should contain dict1 before close") + .isNotNull(); + + cache.close(); + + assertThat(cache.getCurrent()) + .as("Current should be null after close") + .isNull(); + assertThat(cache.get(testDict1.identifier())) + .as("Cache should not contain dict1 after close") + .isNull(); + assertThat(cache.get(testDict2.identifier())) + .as("Cache should not contain dict2 after close") + .isNull(); + } + + @Test + public void testCloseIdempotent() + { + cache.add(testDict1); + cache.setCurrentIfNewer(testDict1); + + // Close multiple times should not cause issues + cache.close(); + cache.close(); + cache.close(); + + assertThat(cache.getCurrent()) + .as("Current should remain null") + .isNull(); + assertThat(cache.get(testDict1.identifier())) + .as("Cache should remain empty") + .isNull(); + } + + @Test + public void testConcurrentAccess() throws InterruptedException + { + int threadCount = 10; + int operationsPerThread = 100; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + AtomicReference errorRef = new AtomicReference<>(); + + // Pre-populate cache + cache.add(testDict1); + cache.add(testDict2); + cache.setCurrentIfNewer(testDict2); + + for (int i = 0; i < threadCount; i++) + { + executor.submit(() -> { + try + { + startLatch.await(); + + for (int j = 0; j < operationsPerThread; j++) + { + // Mix of read operations + CompressionDictionary current = cache.getCurrent(); + cache.get(testDict1.identifier()); + cache.get(testDict2.identifier()); + + // Verify consistency + if (current != null && current.identifier().equals(testDict2.identifier())) + { + successCount.incrementAndGet(); + } + } + } + catch (Exception e) + { + errorRef.set(e); + } + finally + { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); // Start all threads + assertThat(doneLatch.await(10, TimeUnit.SECONDS)) + .as("Threads should complete within timeout") + .isTrue(); + + executor.shutdown(); + + assertThat(errorRef.get()) + .as("No errors should occur during concurrent access") + .isNull(); + assertThat(successCount.get()) + .as("Should have successful read operations") + .isGreaterThan(0); + } + + @Test + public void 
testConcurrentSetCurrent() throws InterruptedException + { + int threadCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch doneLatch = new CountDownLatch(threadCount); + + // Create multiple dictionaries with different IDs + ZstdCompressionDictionary[] dicts = new ZstdCompressionDictionary[threadCount]; + for (int i = 0; i < threadCount; i++) + { + dicts[i] = createTestDictionary(100 + i); // High IDs to ensure newer + } + + for (int i = 0; i < threadCount; i++) + { + ZstdCompressionDictionary dict = dicts[i]; + executor.submit(() -> { + try + { + startLatch.await(); + cache.setCurrentIfNewer(dict); + } + catch (Exception e) + { + // Ignore - testing thread safety + } + finally + { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + assertThat(doneLatch.await(5, TimeUnit.SECONDS)) + .as("Threads should complete within timeout") + .isTrue(); + + executor.shutdown(); + + // Verify that a current dictionary was set and it's one of our test dictionaries + CompressionDictionary current = cache.getCurrent(); + assertThat(current) + .as("A current dictionary should be set") + .isNotNull(); + assertThat(current.identifier().id) + .as("Current dictionary should be one of the test dictionaries") + .isBetween(100L, 100L + threadCount); + + // Clean up + for (ZstdCompressionDictionary dict : dicts) + { + closeQuietly(dict); + } + } + + private static ZstdCompressionDictionary createTestDictionary(long id) + { + try + { + // Create simple dictionary + ZstdDictTrainer trainer = new ZstdDictTrainer(10 * 1024, 1024, 3); + + // Add samples + byte[] sample = TEST_PATTERN.getBytes(); + for (int i = 0; i < 100; i++) + { + trainer.addSample(sample); + } + + byte[] dictBytes = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, id); + + return new ZstdCompressionDictionary(dictId, dictBytes); + } + catch (Exception e) + { + throw new RuntimeException("Failed to create test dictionary", e); + } + } + + private static void closeQuietly(AutoCloseable resource) + { + if (resource != null) + { + try + { + resource.close(); + } + catch (Exception e) + { + // Ignore + } + } + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java new file mode 100644 index 000000000000..143f18cab6bb --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryEventHandlerTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.net.MessagingService; +import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.Schema; +import org.apache.cassandra.schema.TableId; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.tcm.membership.NodeAddresses; +import org.apache.cassandra.tcm.membership.NodeId; +import org.apache.cassandra.tcm.transformations.Register; +import org.apache.cassandra.tcm.transformations.UnsafeJoin; +import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.FBUtilities; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +public class CompressionDictionaryEventHandlerTest +{ + private static final String TEST_NAME = "compression_dict_event_handler_test_"; + private static final String KEYSPACE = TEST_NAME + "keyspace"; + private static final String TABLE = "test_table"; + private static final DictId TEST_DICTIONARY_ID = new DictId(Kind.ZSTD, 12345L); + + private static TableMetadata tableMetadata; + private static ColumnFamilyStore cfs; + + private CompressionDictionaryEventHandler eventHandler; + private ZstdCompressionDictionary testDictionary; + + @BeforeClass + public static void setUpClass() throws Exception + { + ServerTestUtils.prepareServerNoRegister(); + + // Create a table with dictionary compression enabled + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilder = TableMetadata.builder(KEYSPACE, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + + SchemaLoader.createKeyspace(KEYSPACE, + KeyspaceParams.simple(1), + tableBuilder); + + tableMetadata = Schema.instance.getTableMetadata(KEYSPACE, TABLE); + cfs = Keyspace.open(KEYSPACE).getColumnFamilyStore(TABLE); + + // Register some nodes for cluster testing + InetAddressAndPort ep1 = InetAddressAndPort.getByName("127.0.0.2:9042"); + InetAddressAndPort ep2 = InetAddressAndPort.getByName("127.0.0.3:9042"); + InetAddressAndPort ep3 = FBUtilities.getBroadcastAddressAndPort(); + + NodeId node1 = Register.register(new NodeAddresses(UUID.randomUUID(), ep1, ep1, ep1)); + NodeId node2 = Register.register(new NodeAddresses(UUID.randomUUID(), ep2, ep2, ep2)); + NodeId node3 = Register.register(new NodeAddresses(UUID.randomUUID(), 
ep3, ep3, ep3)); + + // Simple token distribution for testing + UnsafeJoin.unsafeJoin(node1, Collections.singleton(key(tableMetadata, 1).getToken())); + UnsafeJoin.unsafeJoin(node2, Collections.singleton(key(tableMetadata, 2).getToken())); + UnsafeJoin.unsafeJoin(node3, Collections.singleton(key(tableMetadata, 3).getToken())); + } + + @Before + public void setUp() + { + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + testDictionary = createTestDictionary(); + eventHandler = new CompressionDictionaryEventHandler(cfs, new CompressionDictionaryCache()); + } + + @After + public void tearDown() + { + if (testDictionary != null) + { + testDictionary.close(); + } + MessagingService.instance().inboundSink.clear(); + MessagingService.instance().outboundSink.clear(); + } + + @Test + public void testOnNewDictionaryTrained() throws InterruptedException + { + // Expect messages to 2 other nodes (excluding self) + CountDownLatch messageSentLatch = new CountDownLatch(2); + Set<InetAddressAndPort> receivers = ConcurrentHashMap.newKeySet(2); + AtomicReference<CompressionDictionaryUpdateMessage> capturedMessage = new AtomicReference<>(); + + // Capture outbound messages + MessagingService.instance().outboundSink.add((message, to) -> { + if (message.verb() == Verb.DICTIONARY_UPDATE_REQ) + { + capturedMessage.set((CompressionDictionaryUpdateMessage) message.payload); + receivers.add(to); + messageSentLatch.countDown(); + } + return false; // Don't actually send + }); + + eventHandler.onNewDictionaryTrained(TEST_DICTIONARY_ID); + + // Wait for message to be processed + assertThat(messageSentLatch.await(5, TimeUnit.SECONDS)) + .as("Dictionary update notification should be sent") + .isTrue(); + + assertThat(receivers) + .as("Should not send notification to self") + .hasSize(2) + .doesNotContain(FBUtilities.getBroadcastAddressAndPort()); + + CompressionDictionaryUpdateMessage message = capturedMessage.get(); + assertThat(message) + .as("Message should be captured") + .isNotNull(); + assertThat(message.tableId) + .as("Message should contain correct table ID") + .isEqualTo(tableMetadata.id); + assertThat(message.dictionaryId) + .as("Message should contain correct dictionary ID") + .isEqualTo(TEST_DICTIONARY_ID); + } + + @Test + public void testMessageSerialization() + { + TableId testTableId = tableMetadata.id; + CompressionDictionaryUpdateMessage message = new CompressionDictionaryUpdateMessage(testTableId, TEST_DICTIONARY_ID); + + assertThat(message.tableId) + .as("Message should contain correct table ID") + .isEqualTo(testTableId); + assertThat(message.dictionaryId) + .as("Message should contain correct dictionary ID") + .isEqualTo(TEST_DICTIONARY_ID); + assertThat(CompressionDictionaryUpdateMessage.serializer) + .as("Message should have serializer") + .isNotNull(); + } + + @Test + public void testMessageSerializationRoundTrip() throws Exception + { + TableId testTableId = tableMetadata.id; + CompressionDictionaryUpdateMessage originalMessage = new CompressionDictionaryUpdateMessage(testTableId, TEST_DICTIONARY_ID); + + // Serialize + org.apache.cassandra.io.util.DataOutputBuffer out = new org.apache.cassandra.io.util.DataOutputBuffer(); + CompressionDictionaryUpdateMessage.serializer.serialize(originalMessage, out, MessagingService.current_version); + + // Deserialize + org.apache.cassandra.io.util.DataInputBuffer in = new org.apache.cassandra.io.util.DataInputBuffer(out.getData()); + CompressionDictionaryUpdateMessage deserializedMessage = + CompressionDictionaryUpdateMessage.serializer.deserialize(in,
MessagingService.current_version); + + assertThat(deserializedMessage.tableId) + .as("Deserialized table ID should match") + .isEqualTo(originalMessage.tableId); + assertThat(deserializedMessage.dictionaryId) + .as("Deserialized dictionary ID should match") + .isEqualTo(originalMessage.dictionaryId); + } + + @Test + public void testSendNotificationRobustness() + { + // Test that sending notifications doesn't throw even if messaging fails + MessagingService.instance().outboundSink.add((message, to) -> { + if (message.verb() == Verb.DICTIONARY_UPDATE_REQ) + { + throw new RuntimeException("Simulated messaging failure"); + } + return false; + }); + + assertThatNoException().isThrownBy(() -> eventHandler.onNewDictionaryTrained(TEST_DICTIONARY_ID)); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + byte[] dictBytes = "test dictionary data for event handler testing".getBytes(); + return new ZstdCompressionDictionary(TEST_DICTIONARY_ID, dictBytes); + } + + private static DecoratedKey key(TableMetadata metadata, int key) + { + return metadata.partitioner.decorateKey(ByteBufferUtil.bytes(key)); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java new file mode 100644 index 000000000000..16f313b95bc7 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryIntegrationTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.schema.CompressionParams; + +import static org.apache.cassandra.Util.spinUntilTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class CompressionDictionaryIntegrationTest extends CQLTester +{ + private static final String REPEATED_DATA = "The quick brown fox jumps over the lazy dog. 
This text repeats for better compression. "; + + @Before + public void configureDatabaseDescriptor() + { + Config config = DatabaseDescriptor.getRawConfig(); + config.compression_dictionary_training_sampling_rate = 1; + config.compression_dictionary_training_max_total_sample_size = 128 * 1024; + config.compression_dictionary_training_max_dictionary_size = 10 * 1024; + // Ensures that data is still sampled when using LZ4 (which is picked up when using 'fast') + // on the SSTable flushing code path + config.flush_compression = Config.FlushCompression.fast; + DatabaseDescriptor.setConfig(config); + } + + @Test + public void testEndToEndDictionaryTraining() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Verify initial state + assertThat(manager.getTrainingStatus()) + .as("Initial training status should be NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + + // Trigger manual training + manager.train(Map.of("maxSamplingDurationSeconds", "2")); + + // Add sample data that benefits from dictionary compression + int i = 0; + while (!manager.isReady()) + { + ByteBuffer sample = ByteBuffer.wrap((REPEATED_DATA + " variation " + i++).getBytes()); + manager.addSample(sample); + } + + assertThat(manager.isReady()) + .as("Trainer should be ready to train") + .isTrue(); + + // Training should complete + spinUntilTrue(() -> manager.getTrainingStatus().equals(TrainingStatus.COMPLETED.toString()), 2); + + // Verify dictionary is available + // There could be a slight delay, as the dictionary has to be persisted to the system table first.
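+ // Wait until the trained dictionary becomes available as the table's current dictionary.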
+ spinUntilTrue(() -> manager.getCurrent() != null, 2); + + CompressionDictionary currentDict = manager.getCurrent(); + + assertThat(currentDict.kind()) + .as("Dictionary should be ZSTD type") + .isEqualTo(Kind.ZSTD); + + assertThat(currentDict.rawDictionary().length) + .as("Dictionary should have content") + .isGreaterThan(0); + } + + @Test + public void testEnableDisableDictionaryCompression() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + assertThatNoException() + .as("Should allow manual training") + .isThrownBy(() -> manager.train(Map.of("maxSamplingDurationSeconds", "600"))); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + manager.maybeReloadFromSchema(nonDictParams); + + assertThatThrownBy(() -> manager.train(Map.of("maxSamplingDurationSeconds", "600"))) + .as("Should disallow manual training when using lz4") + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not support dictionary compression"); + + // Re-enable dictionary compression + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Collections.singletonMap("compression_level", "3")); + manager.maybeReloadFromSchema(dictParams); + + assertThatNoException() + .as("Should allow manual training after switching back to dictionary compression") + .isThrownBy(() -> manager.train(Map.of("maxSamplingDurationSeconds", "600"))); + } + + @Test + public void testCompressionParameterChanges() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + ICompressionDictionaryTrainer trainer = manager.trainer(); + assertThat(trainer).isNotNull(); + assertThat(trainer.kind()).isEqualTo(Kind.ZSTD); + + // Change compression level - should create new trainer + CompressionParams newParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Collections.singletonMap("compression_level", "5")); + manager.maybeReloadFromSchema(newParams); + ICompressionDictionaryTrainer newTrainer = manager.trainer(); + assertThat(newTrainer.kind()).isEqualTo(Kind.ZSTD); + assertThat(newTrainer) + .as("Should create a different trainer instance when compression level is changed") + .isNotSameAs(trainer); + } + + @Test + public void testSSTableCompressionWithDictionary() + { + DatabaseDescriptor.setFlushCompression(Config.FlushCompression.table); + String table = createTable("CREATE TABLE %s (pk text PRIMARY KEY, data text) " + + // use 4 KiB, so it collects enough samples. 
Trainer requires at least 10 samples + "WITH compression = {'class': 'ZstdDictionaryCompressor', 'chunk_length_in_kb' : 4}"); + ColumnFamilyStore cfs = Keyspace.open(keyspace()).getColumnFamilyStore(table); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + manager.train(Map.of("maxSamplingDurationSeconds", "5")); + + // Insert compressible data to train dictionary + int i = 0; + while (!manager.isReady()) + { + int index = i++; + execute("INSERT INTO %s (pk, data) VALUES (?, ?)", + "key" + index, + REPEATED_DATA + " row " + index); + if (i % 200 == 0) + flush(); + } + flush(); + + // training should finish in 3 seconds and have the dictionary available + spinUntilTrue(() -> manager.getCurrent() != null, 3); + + // Insert compressible data to be compressed by dictionary + for (int j = i; j < i + 500; j++) + { + execute("INSERT INTO %s (pk, data) VALUES (?, ?)", + "key" + j, + REPEATED_DATA + " row " + j); + } + + // Verify SSTable was created with compression + assertThat(cfs.getLiveSSTables()) + .as("Should have created SSTables") + .hasSizeGreaterThan(1); + + SSTableReader sstable = cfs.getLiveSSTables().iterator().next(); + assertThat(sstable.compression) + .as("SSTable should have compression parameters") + .isNotNull(); + + // Verify data can be read back correctly + // - Can read data from the sstable w/o dictionary + assertRows(execute("SELECT pk, data FROM %s WHERE pk = ?", "key0"), + row("key0", REPEATED_DATA + " row 0")); + // - Can read data from the sstable w/ dictionary + int rowInDictSSTable = i + 100; + assertRows(execute("SELECT pk, data FROM %s WHERE pk = ?", "key" + rowInDictSSTable), + row("key" + rowInDictSSTable, REPEATED_DATA + " row " + rowInDictSSTable)); + } + + @Test + public void testResourceCleanupOnClose() throws Exception + { + createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + ColumnFamilyStore cfs = getCurrentColumnFamilyStore(); + CompressionDictionaryManager manager = cfs.compressionDictionaryManager(); + + // Add test dictionary + ZstdCompressionDictionary testDict = createTestDictionary(); + manager.add(testDict); + manager.setCurrentIfNewer(testDict); + + assertThat(testDict.selfRef().globalCount()) + .as("Dictionary's reference count should be 1 after adding to cache") + .isOne(); + + assertThat(manager.getCurrent()) + .as("Should have current dictionary before close") + .isNotNull(); + + manager.close(); + + assertThat(manager.trainer()).isNull(); + assertThat(testDict.selfRef().globalCount()) + .as("Dictionary's reference count should be 0 after closing manager") + .isZero(); + assertThat(testDict.rawDictionary()) + .as("The raw dictionary bytes should still be accessible") + .isNotNull(); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + byte[] dictBytes = (REPEATED_DATA + " dictionary training data").getBytes(); + DictId dictId = new DictId(Kind.ZSTD, System.currentTimeMillis()); + return new ZstdCompressionDictionary(dictId, dictBytes); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java new file mode 100644 index 000000000000..fbd56e2d2221 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerMBeanTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.util.Map; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.utils.MBeanWrapper; +import org.apache.cassandra.utils.MBeanWrapper.OnException; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryManagerMBeanTest +{ + private static final String KEYSPACE_WITH_DICT = "keyspace_mbean_test"; + private static final String TABLE = "test_table"; + + private static ColumnFamilyStore cfsWithDict; + + @BeforeClass + public static void setUpClass() throws Exception + { + ServerTestUtils.prepareServer(); + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + TableMetadata.Builder tableBuilder = TableMetadata.builder(KEYSPACE_WITH_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + SchemaLoader.createKeyspace(KEYSPACE_WITH_DICT, + KeyspaceParams.simple(1), + tableBuilder); + cfsWithDict = Keyspace.open(KEYSPACE_WITH_DICT).getColumnFamilyStore(TABLE); + } + + // Ensure no mbean is registered at the begining of the test + @Before + public void cleanup() + { + String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE); + MBeanWrapper.instance.unregisterMBean(mbeanName, OnException.IGNORE); + } + + @Test + public void testMBeanRegisteredWhenBookkeepingEnabled() + { + String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE); + // Create manager with bookkeeping enabled + try (CompressionDictionaryManager manager = new CompressionDictionaryManager(cfsWithDict, true)) + { + // Verify MBean is registered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be registered when bookkeeping is enabled") + .isTrue(); + } + // Closing manager should unregister the mbean; Verify it is unregistered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be unregistered after unregisterMbean() call") + .isFalse(); + } + + @Test + public void testMBeanNotRegisteredWhenBookkeepingDisabled() + { + // Create manager with bookkeeping disabled + try (CompressionDictionaryManager manager = new CompressionDictionaryManager(cfsWithDict, false)) + { + // Verify MBean is NOT registered + 
String mbeanName = CompressionDictionaryManager.mbeanName(KEYSPACE_WITH_DICT, TABLE);; + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should not be registered when bookkeeping is disabled") + .isFalse(); + } + // Closing manager should not throw due to mbean not registered + } + + @Test + public void testMBeanUnregisteredOnCFSInvalidation() + { + String testKeyspace = "test_invalidation_mbean_ks"; + String testTable = "test_invalidation_mbean_table"; + + CompressionParams compressionParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilder = TableMetadata.builder(testKeyspace, testTable) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParams); + + SchemaLoader.createKeyspace(testKeyspace, + KeyspaceParams.simple(1), + tableBuilder); + + ColumnFamilyStore cfs = Keyspace.open(testKeyspace).getColumnFamilyStore(testTable); + + String mbeanName = CompressionDictionaryManager.mbeanName(testKeyspace, testTable); + + // Verify MBean is registered (CFS registers it during creation) + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be registered after CFS creation") + .isTrue(); + + // Invalidate the CFS (which should unregister the MBean) + cfs.invalidate(true, true); + + // Verify MBean is unregistered + assertThat(MBeanWrapper.instance.isRegistered(mbeanName)) + .as("MBean should be unregistered after CFS invalidation") + .isFalse(); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java new file mode 100644 index 000000000000..51a8ce664dbc --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryManagerTest.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.config.CassandraRelevantProperties; +import org.apache.cassandra.db.ColumnFamilyStore; +import org.apache.cassandra.db.Keyspace; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; +import org.apache.cassandra.schema.TableMetadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class CompressionDictionaryManagerTest +{ + private static final String KEYSPACE_WITH_DICT = "keyspace_with_dict"; + private static final String KEYSPACE_WITHOUT_DICT = "keyspace_without_dict"; + private static final String TABLE = "test_table"; + + private static ColumnFamilyStore cfsWithDict; + private static ColumnFamilyStore cfsWithoutDict; + + private CompressionDictionaryManager managerWithDict; + private CompressionDictionaryManager managerWithoutDict; + + @BeforeClass + public static void setUpClass() throws Exception + { + CassandraRelevantProperties.ORG_APACHE_CASSANDRA_DISABLE_MBEAN_REGISTRATION.setBoolean(true); + ServerTestUtils.prepareServerNoRegister(); + + // Create table with dictionary compression enabled + CompressionParams compressionParamsWithDict = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + TableMetadata.Builder tableBuilderWithDict = TableMetadata.builder(KEYSPACE_WITH_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParamsWithDict); + + // Create table without dictionary compression + CompressionParams compressionParamsWithoutDict = CompressionParams.lz4(); + + TableMetadata.Builder tableBuilderWithoutDict = TableMetadata.builder(KEYSPACE_WITHOUT_DICT, TABLE) + .addPartitionKeyColumn("pk", org.apache.cassandra.db.marshal.UTF8Type.instance) + .addRegularColumn("data", org.apache.cassandra.db.marshal.UTF8Type.instance) + .compression(compressionParamsWithoutDict); + + SchemaLoader.createKeyspace(KEYSPACE_WITH_DICT, + KeyspaceParams.simple(1), + tableBuilderWithDict); + + SchemaLoader.createKeyspace(KEYSPACE_WITHOUT_DICT, + KeyspaceParams.simple(1), + tableBuilderWithoutDict); + + cfsWithDict = Keyspace.open(KEYSPACE_WITH_DICT).getColumnFamilyStore(TABLE); + cfsWithoutDict = Keyspace.open(KEYSPACE_WITHOUT_DICT).getColumnFamilyStore(TABLE); + } + + @Before + public void setUp() + { + managerWithDict = new CompressionDictionaryManager(cfsWithDict, true); + managerWithoutDict = new CompressionDictionaryManager(cfsWithoutDict, true); + } + + @After + public void tearDown() throws Exception + { + if (managerWithDict != null) + { + managerWithDict.close(); + } + if (managerWithoutDict != null) + { + managerWithoutDict.close(); + } + } + + @Test + public void testManagerInitializationWithDictionaryCompression() + { + assertThat(managerWithDict) + .as("Manager should be created successfully for dictionary-enabled table") + .isNotNull(); + + // Manager should 
start in a valid state + String status = managerWithDict.getTrainingStatus(); + assertThat(status) + .as("Training status should be valid") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + } + + @Test + public void testManagerInitializationWithoutDictionaryCompression() + { + assertThat(managerWithoutDict) + .as("Manager should be created successfully for non-dictionary table") + .isNotNull(); + + // Should report NOT_STARTED since no trainer is created + String status = managerWithoutDict.getTrainingStatus(); + assertThat(status) + .as("Should report NOT_STARTED for non-dictionary tables") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + } + + @Test + public void testMaybeReloadFromSchemaEnableDictionaryCompression() + { + // Start with manager for non-dictionary table + String initialStatus = managerWithoutDict.getTrainingStatus(); + assertThat(initialStatus) + .as("Initially should not be training") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + + // Enable dictionary compression by switching to dict params + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + managerWithoutDict.maybeReloadFromSchema(dictParams); + + managerWithoutDict.train(Map.of("maxSamplingDurationSeconds", "600")); + // Should now have training capability + String newStatus = managerWithoutDict.getTrainingStatus(); + assertThat(newStatus) + .as("Should now support training") + .isEqualTo(TrainingStatus.SAMPLING.toString()); + } + + @Test + public void testMaybeReloadFromSchemaDisableDictionaryCompression() + { + managerWithDict.train(Map.of("maxSamplingDurationSeconds", "600")); + String status = managerWithDict.getTrainingStatus(); + assertThat(status) + .as("Should be sampling") + .isEqualTo(TrainingStatus.SAMPLING.toString()); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + managerWithDict.maybeReloadFromSchema(nonDictParams); + + // Should disable training + String newStatus = managerWithDict.getTrainingStatus(); + assertThat(newStatus) + .as("Should disable training when dictionary compression is disabled") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + } + + @Test + public void testTrainerCompatibilityCheck() + { + managerWithDict.train(Map.of("maxSamplingDurationSeconds", "600")); + String initialStatus = managerWithDict.getTrainingStatus(); + assertThat(initialStatus) + .as("Should be sampling") + .isEqualTo(TrainingStatus.SAMPLING.toString()); + + // Change compression level - should create new trainer + CompressionParams differentLevelParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "5")); + managerWithDict.maybeReloadFromSchema(differentLevelParams); + String newStatus = managerWithDict.getTrainingStatus(); + + // Status should reset due to trainer replacement + assertThat(newStatus) + .as("Should reset status when creating new trainer") + .isEqualTo(TrainingStatus.NOT_STARTED.toString()); + } + + @Test + public void testAddSample() + { + ByteBuffer sample = ByteBuffer.wrap("test sample data".getBytes()); + ByteBuffer emptyBuffer = ByteBuffer.allocate(0); + + // Should not throw for dictionary-enabled table + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(sample)); + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(null)); + assertThatNoException().isThrownBy(() -> managerWithDict.addSample(emptyBuffer)); + // Should not throw for 
non-dictionary table (graceful handling) + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(sample)); + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(null)); + assertThatNoException().isThrownBy(() -> managerWithoutDict.addSample(emptyBuffer)); + } + + @Test + public void testTrainManualWithNonDictionaryTable() + { + assertThatThrownBy(() -> managerWithoutDict.train(Map.of("maxSamplingDurationSeconds", "600"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not support dictionary compression"); + } + + @Test + public void testTrainManualWithMissingParameters() + { + assertThatThrownBy(() -> managerWithDict.train(Map.of())) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds parameter is required"); + + assertThatThrownBy(() -> managerWithDict.train(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds parameter is required"); + } + + @Test + public void testTrainManualWithInvalidParameters() + { + assertThatThrownBy(() -> managerWithDict.train(Map.of("maxSamplingDurationSeconds", "invalid"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid maxSamplingDurationSeconds value: invalid") + .hasCauseInstanceOf(NumberFormatException.class); + + assertThatThrownBy(() -> managerWithDict.train(Map.of("maxSamplingDurationSeconds", "-1"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds must be positive, got: -1"); + } + + @Test + public void testTrainManualWithOptions() + { + // Should accept custom options + managerWithDict.train(Map.of("maxSamplingDurationSeconds", "30")); + + String status = managerWithDict.getTrainingStatus(); + assertThat(status) + .as("Training with options should work") + .isEqualTo(TrainingStatus.SAMPLING.toString()); + } + + @Test + public void testSchemaChangeWorkflow() + { + // Start with non-dictionary table + String initialStatus = managerWithoutDict.getTrainingStatus(); + assertThat(initialStatus).isEqualTo(TrainingStatus.NOT_STARTED.toString()); + + // Enable dictionary compression + CompressionParams dictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + managerWithoutDict.maybeReloadFromSchema(dictParams); + managerWithoutDict.train(Map.of("maxSamplingDurationSeconds", "600")); + // Should now support training + String enabledStatus = managerWithoutDict.getTrainingStatus(); + assertThat(enabledStatus).isEqualTo(TrainingStatus.SAMPLING.toString()); + + // Change compression level + CompressionParams newDictParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "5")); + managerWithoutDict.maybeReloadFromSchema(newDictParams); + + // Should still support training with new parameters + String updatedStatus = managerWithoutDict.getTrainingStatus(); + assertThat(updatedStatus).isEqualTo(TrainingStatus.NOT_STARTED.toString()); + managerWithoutDict.train(Map.of("maxSamplingDurationSeconds", "600")); + assertThat(enabledStatus).isEqualTo(TrainingStatus.SAMPLING.toString()); + + // Disable dictionary compression + CompressionParams nonDictParams = CompressionParams.lz4(); + managerWithoutDict.maybeReloadFromSchema(nonDictParams); + + // Should disable training + String disabledStatus = managerWithoutDict.getTrainingStatus(); + assertThat(disabledStatus).isEqualTo(TrainingStatus.NOT_STARTED.toString()); 
+ } + + @Test + public void testUpdateSamplingRate() + { + // Test with enabled dictionary manager + managerWithDict.train(Map.of("maxSamplingDurationSeconds", "600")); + + // Should be able to update sampling rate + assertThatNoException().isThrownBy(() -> managerWithDict.updateSamplingRate(5)); + assertThatNoException().isThrownBy(() -> managerWithDict.updateSamplingRate(1)); + assertThatNoException().isThrownBy(() -> managerWithDict.updateSamplingRate(100)); + } + + @Test + public void testUpdateSamplingRateWithoutTrainer() + { + // Test with disabled dictionary manager (no trainer) + assertThatThrownBy(() -> managerWithoutDict.updateSamplingRate(5)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Dictionary trainer is not available"); + } + + @Test + public void testUpdateSamplingRateValidation() + { + // Test with enabled dictionary manager + managerWithDict.train(Map.of("maxSamplingDurationSeconds", "600")); + + // Test invalid sampling rates are rejected by the trainer + assertThatThrownBy(() -> managerWithDict.updateSamplingRate(0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + + assertThatThrownBy(() -> managerWithDict.updateSamplingRate(-1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java new file mode 100644 index 000000000000..d1099aa22f16 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionarySchedulerTest.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.SchemaLoader; +import org.apache.cassandra.ServerTestUtils; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; +import org.apache.cassandra.schema.KeyspaceParams; + +import static org.apache.cassandra.Util.spinUntilTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class CompressionDictionarySchedulerTest +{ + private static final String TEST_NAME = "compression_dict_scheduler_test_"; + private static final String KEYSPACE = TEST_NAME + "keyspace"; + private static final String TABLE = "test_table"; + + private CompressionDictionaryScheduler scheduler; + private TestDictionaryTrainer testTrainer; + private ZstdCompressionDictionary testDictionary; + private ICompressionDictionaryCache testCache; + + @BeforeClass + public static void setUpClass() throws Exception + { + ServerTestUtils.prepareServerNoRegister(); + SchemaLoader.createKeyspace(KEYSPACE, KeyspaceParams.simple(1)); + } + + @Before + public void setUp() + { + testTrainer = new TestDictionaryTrainer(); + testDictionary = createTestDictionary(); + testCache = new CompressionDictionaryCache(); + scheduler = new CompressionDictionaryScheduler(KEYSPACE, TABLE, testCache, true); + } + + @After + public void tearDown() throws Exception + { + if (scheduler != null) + { + scheduler.close(); + } + if (testDictionary != null) + { + testDictionary.close(); + } + if (testCache != null) + { + testCache.close(); + } + } + + @Test + public void testScheduleManualTraining() + { + testManualTraining(false, new ManualTrainingOptions(600)); + } + + @Test + public void testScheduleManualTrainingWithCustomDuration() + { + testManualTraining(true, new ManualTrainingOptions(1)); + } + + @Test + public void testConcurrentTraining() + { + ManualTrainingOptions options = new ManualTrainingOptions(600); + + testTrainer.setReady(true); + testTrainer.setTrainingResult(CompletableFuture.completedFuture(testDictionary)); + + // Schedule first training + scheduler.scheduleManualTraining(options, testTrainer); + + // Attempt to schedule second training should fail + assertThatThrownBy(() -> scheduler.scheduleManualTraining(options, testTrainer)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Training already in progress"); + } + + @Test + public void testManualTrainingFailure() + { + ManualTrainingOptions options = new ManualTrainingOptions(600); + + testTrainer.setReady(true); + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Training failed")); + testTrainer.setTrainingResult(failedFuture); + + scheduler.scheduleManualTraining(options, testTrainer); + + // Expect the trainer to fail + spinUntilTrue(() -> testTrainer.getTrainingStatus() == TrainingStatus.FAILED, 5); + } + + @Test + public void testTrainerNotStarted() + { + ManualTrainingOptions options = new 
ManualTrainingOptions(600); + + testTrainer.setTrainingStatus(TrainingStatus.NOT_STARTED); + + scheduler.scheduleManualTraining(options, testTrainer); + assertThat((Object) scheduler.scheduledManualTrainingTask()).isNotNull(); + + // Expect the manual training task to be cleaned up + spinUntilTrue(() -> scheduler.scheduledManualTrainingTask() == null, 5); + } + + private void testManualTraining(boolean expectForceTraining, ManualTrainingOptions trainOptions) + { + boolean ready = !expectForceTraining; + testTrainer.setReady(ready); + testTrainer.setTrainingResult(CompletableFuture.completedFuture(testDictionary)); + AtomicReference dictHolder = new AtomicReference<>(); + testTrainer.setDictionaryTrainedListener(dictHolder::set); + + assertThat(dictHolder.get()) + .as("No dictionary is available before training") + .isNull(); + + scheduler.scheduleManualTraining(trainOptions, testTrainer); + + // Wait until dictionary is trained and notified + spinUntilTrue(() -> dictHolder.get() == testDictionary, 5); + assertThat(testTrainer.isForceTrained).isEqualTo(expectForceTraining); + + assertThat(testTrainer.getTrainDictionaryAsyncCallCount()) + .as("trainDictionaryAsync should be called") + .isGreaterThan(0); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + byte[] dictBytes = "test dictionary data for scheduler testing".getBytes(); + DictId dictId = new DictId(Kind.ZSTD, System.currentTimeMillis()); + return new ZstdCompressionDictionary(dictId, dictBytes); + } + + /** + * Test implementation of dictionary trainer + */ + private static class TestDictionaryTrainer implements ICompressionDictionaryTrainer + { + public volatile boolean isForceTrained = false; + private final AtomicInteger trainDictionaryAsyncCallCount = new AtomicInteger(0); + private volatile TrainingStatus trainingStatus = TrainingStatus.SAMPLING; + private volatile boolean ready = false; + private volatile CompletableFuture trainingResult = null; + private volatile Consumer onDictionaryTrained = null; + + @Override + public boolean shouldSample() + { + return true; + } + + @Override + public void addSample(java.nio.ByteBuffer sample) + { + // No-op for testing + } + + @Override + public CompressionDictionary trainDictionary(boolean force) + { + throw new RuntimeException("Not expected to be called in test"); + } + + @Override + public CompletableFuture trainDictionaryAsync(boolean force) + { + trainDictionaryAsyncCallCount.incrementAndGet(); + isForceTrained = force; + if (trainingResult != null) + { + if (trainingResult.isCompletedExceptionally()) + { + trainingStatus = TrainingStatus.FAILED; + } + else + { + trainingStatus = TrainingStatus.COMPLETED; + try + { + onDictionaryTrained.accept(trainingResult.get()); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + return trainingResult; + } + + return CompletableFuture.completedFuture(createTestDictionary()); + } + + @Override + public boolean isReady() + { + return ready; + } + + @Override + public void reset() + { + trainingStatus = TrainingStatus.NOT_STARTED; + ready = false; + } + + @Override + public TrainingStatus getTrainingStatus() + { + return trainingStatus; + } + + @Override + public boolean start(boolean manualTraining) + { + if (trainingStatus == TrainingStatus.NOT_STARTED) + { + trainingStatus = TrainingStatus.SAMPLING; + return true; + } + return false; + } + + @Override + public Kind kind() + { + return Kind.ZSTD; + } + + @Override + public boolean isCompatibleWith(CompressionParams newParams) + { + return true; // 
Simplified for testing + } + + @Override + public void close() + { + trainingStatus = TrainingStatus.NOT_STARTED; + } + + @Override + public void setDictionaryTrainedListener(Consumer listener) + { + this.onDictionaryTrained = listener; + } + + @Override + public void updateSamplingRate(int newSamplingRate) + { + // not used in test + } + + // Test helper methods + public void setReady(boolean ready) + { + this.ready = ready; + } + + public void setTrainingStatus(TrainingStatus status) + { + this.trainingStatus = status; + } + + public void setTrainingResult(CompletableFuture result) + { + this.trainingResult = result; + } + + public int getTrainDictionaryAsyncCallCount() + { + return trainDictionaryAsyncCallCount.get(); + } + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java new file mode 100644 index 000000000000..29e7d361a000 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/CompressionDictionaryTrainingConfigTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionDictionaryTrainingConfigTest +{ + @Test + public void testBuilderDefaults() + { + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder().build(); + + assertThat(config.maxDictionarySize) + .as("Default max dictionary size should be 64KB") + .isEqualTo(65536); + assertThat(config.maxTotalSampleSize) + .as("Default max total sample size should be 10MB") + .isEqualTo(10 * 1024 * 1024); + assertThat(config.samplingRate) + .as("Default sampling rate should be 100 (1%)") + .isEqualTo(100); + } + + @Test + public void testCalculatedThresholds() + { + int dictSize = 16 * 1024; // 16KB + int sampleSize = 2 * 1024 * 1024; // 2MB + int samplingRate = 200; // 0.5% + + CompressionDictionaryTrainingConfig config = CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(dictSize) + .maxTotalSampleSize(sampleSize) + .samplingRate(samplingRate) + .build(); + + // Verify all calculated values are consistent + assertThat(config.maxDictionarySize).isEqualTo(dictSize); + assertThat(config.maxTotalSampleSize).isEqualTo(sampleSize); + assertThat(config.acceptableTotalSampleSize).isEqualTo(sampleSize / 10 * 8); + assertThat(config.samplingRate).isEqualTo(samplingRate); + + // Verify relationship between max and acceptable sample sizes + assertThat(config.acceptableTotalSampleSize) + .as("Acceptable sample size should be less than or equal to max") + .isLessThanOrEqualTo(config.maxTotalSampleSize); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/ManualTrainingOptionsTest.java b/test/unit/org/apache/cassandra/db/compression/ManualTrainingOptionsTest.java new file mode 100644 index 000000000000..f845037bb9dd --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/ManualTrainingOptionsTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.db.compression; + +import java.util.Map; + +import org.junit.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ManualTrainingOptionsTest +{ + @Test + public void testValidConstruction() + { + ManualTrainingOptions options = new ManualTrainingOptions(600); + assertThat(options.getMaxSamplingDurationSeconds()).isEqualTo(600); + } + + @Test + public void testInvalidDurationThrows() + { + assertThatThrownBy(() -> new ManualTrainingOptions(0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds must be positive, got: 0"); + + assertThatThrownBy(() -> new ManualTrainingOptions(-1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds must be positive, got: -1"); + } + + @Test + public void testFromStringMapValid() + { + Map options = Map.of("maxSamplingDurationSeconds", "300"); + ManualTrainingOptions trainingOptions = ManualTrainingOptions.fromStringMap(options); + + assertThat(trainingOptions.getMaxSamplingDurationSeconds()).isEqualTo(300); + } + + @Test + public void testFromStringMapMissingKey() + { + Map emptyOptions = Map.of(); + + assertThatThrownBy(() -> ManualTrainingOptions.fromStringMap(emptyOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds parameter is required"); + + assertThatThrownBy(() -> ManualTrainingOptions.fromStringMap(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds parameter is required"); + } + + @Test + public void testFromStringMapInvalidValue() + { + Map invalidOptions = Map.of("maxSamplingDurationSeconds", "invalid"); + + assertThatThrownBy(() -> ManualTrainingOptions.fromStringMap(invalidOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid maxSamplingDurationSeconds value: invalid") + .hasCauseInstanceOf(NumberFormatException.class); + } + + @Test + public void testFromStringMapNegativeValue() + { + Map negativeOptions = Map.of("maxSamplingDurationSeconds", "-1"); + + assertThatThrownBy(() -> ManualTrainingOptions.fromStringMap(negativeOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds must be positive, got: -1"); + } + + @Test + public void testFromStringMapZeroValue() + { + Map zeroOptions = Map.of("maxSamplingDurationSeconds", "0"); + + assertThatThrownBy(() -> ManualTrainingOptions.fromStringMap(zeroOptions)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("maxSamplingDurationSeconds must be positive, got: 0"); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java new file mode 100644 index 000000000000..c47a0bb00eca --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/ZstdCompressionDictionaryTest.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.github.luben.zstd.ZstdDictCompress; +import com.github.luben.zstd.ZstdDictDecompress; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.io.compress.ZstdCompressorBase; +import org.apache.cassandra.utils.concurrent.Ref; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ZstdCompressionDictionaryTest +{ + private static final byte[] SAMPLE_DICT_DATA = createSampleDictionaryData(); + private static final DictId SAMPLE_DICT_ID = new DictId(Kind.ZSTD, 123456789L); + + private ZstdCompressionDictionary dictionary; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + dictionary = new ZstdCompressionDictionary(SAMPLE_DICT_ID, SAMPLE_DICT_DATA); + } + + @Test + public void testEqualsAndHashCode() + { + ZstdCompressionDictionary dictionary2 = new ZstdCompressionDictionary(SAMPLE_DICT_ID, SAMPLE_DICT_DATA); + ZstdCompressionDictionary differentIdDict = new ZstdCompressionDictionary( + new DictId(Kind.ZSTD, 987654321L), SAMPLE_DICT_DATA); + + assertThat(dictionary) + .as("Dictionaries with same ID should be equal") + .isEqualTo(dictionary2); + + assertThat(dictionary.hashCode()) + .as("Hash codes should be equal for same ID") + .isEqualTo(dictionary2.hashCode()); + + assertThat(dictionary) + .as("Dictionaries with different IDs should not be equal") + .isNotEqualTo(differentIdDict); + + dictionary2.close(); + differentIdDict.close(); + } + + @Test + public void testDictionaryForCompression() + { + int compressionLevel = 3; + ZstdDictCompress compressDict = dictionary.dictionaryForCompression(compressionLevel); + + assertThat(compressDict) + .as("Compression dictionary should not be null") + .isNotNull(); + + // Calling again should return the same cached instance + ZstdDictCompress compressDict2 = dictionary.dictionaryForCompression(compressionLevel); + assertThat(compressDict2) + .as("Second call should return cached instance") + .isSameAs(compressDict); + } + + @Test + public void testDictionaryForCompressionMultipleLevels() + { + ZstdDictCompress level1 = dictionary.dictionaryForCompression(1); + ZstdDictCompress level3 = dictionary.dictionaryForCompression(3); + ZstdDictCompress level6 = dictionary.dictionaryForCompression(6); + + assertThat(level1) + .as("Level 1 compression dictionary should not 
be null") + .isNotNull(); + + assertThat(level3) + .as("Level 3 compression dictionary should not be null") + .isNotNull(); + + assertThat(level6) + .as("Level 6 compression dictionary should not be null") + .isNotNull(); + + assertThat(level1) + .as("Different compression levels should have different instances") + .isNotSameAs(level3); + + assertThat(level3) + .as("Different compression levels should have different instances") + .isNotSameAs(level6); + } + + @Test + public void testDictionaryForDecompression() + { + ZstdDictDecompress decompressDict = dictionary.dictionaryForDecompression(); + + assertThat(decompressDict) + .as("Decompression dictionary should not be null") + .isNotNull(); + + ZstdDictDecompress decompressDict2 = dictionary.dictionaryForDecompression(); + assertThat(decompressDict2) + .as("Second call should return cached instance") + .isSameAs(decompressDict); + } + + @Test + public void testInvalidCompressionLevel() + { + // Test with various invalid compression levels + assertThatThrownBy(() -> dictionary.dictionaryForCompression(ZstdCompressorBase.FAST_COMPRESSION_LEVEL - 1)) + .as("Negative compression level should throw exception") + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("is invalid"); + + assertThatThrownBy(() -> dictionary.dictionaryForCompression(100)) + .as("Too high compression level should throw exception") + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("is invalid"); + } + + @Test + public void testDictionaryClose() + { + // Access some dictionaries first + dictionary.dictionaryForCompression(3); + dictionary.dictionaryForDecompression(); + + dictionary.close(); + + assertThatThrownBy(() -> dictionary.dictionaryForCompression(3)) + .as("Should throw exception when accessing closed dictionary") + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Dictionary has been closed"); + + assertThatThrownBy(() -> dictionary.dictionaryForDecompression()) + .as("Should throw exception when accessing closed dictionary") + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Dictionary has been closed"); + } + + @Test + public void testTryRef() + { + Ref ref = dictionary.tryRef(); + + assertThat(ref) + .as("tryRef should return non-null reference") + .isNotNull(); + + assertThat(ref.get()) + .as("Reference should point to same dictionary") + .isSameAs(dictionary); + + ref.release(); + } + + @Test + public void testMultipleReferences() + { + Ref ref1 = dictionary.ref(); + Ref ref2 = dictionary.ref(); + Ref ref3 = dictionary.tryRef(); + + assertThat(ref1.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + assertThat(ref2.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + assertThat(ref3.get()) + .as("All references should point to same dictionary") + .isSameAs(dictionary); + + // Dictionary should still be accessible + assertThat(dictionary.dictionaryForCompression(3)) + .as("Dictionary should still be accessible with multiple refs") + .isNotNull(); + + ref1.release(); + ref2.release(); + ref3.release(); + } + + @Test + public void testReferenceAfterClose() + { + dictionary.close(); + + assertThatThrownBy(() -> dictionary.ref()) + .as("Should not be able to get reference after close") + .isInstanceOf(AssertionError.class); + + Ref tryRef = dictionary.tryRef(); + assertThat(tryRef) + .as("tryRef should return null after close") + .isNull(); + } + + @Test + public void testConcurrentAccess() throws Exception + 
{ + ExecutorService executor = Executors.newFixedThreadPool(4); + AtomicInteger successCount = new AtomicInteger(0); + int numTasks = 100; + + try + { + Future[] futures = new Future[numTasks]; + + for (int i = 0; i < numTasks; i++) + { + final int level = (i % 6) + 1; // Compression levels 1-6 + futures[i] = executor.submit(() -> { + try + { + Ref ref = dictionary.ref(); + ZstdDictCompress compressDict = ref.get().dictionaryForCompression(level); + ZstdDictDecompress decompressDict = ref.get().dictionaryForDecompression(); + + assertThat(compressDict).isNotNull(); + assertThat(decompressDict).isNotNull(); + + successCount.incrementAndGet(); + ref.release(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + }); + } + + // Wait for all tasks to complete + for (Future future : futures) + { + future.get(5, TimeUnit.SECONDS); + } + + assertThat(successCount.get()) + .as("All concurrent accesses should succeed") + .isEqualTo(numTasks); + } + finally + { + executor.shutdown(); + executor.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + public void testSerializeDeserialize() throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + dictionary.serialize(dos); + dos.flush(); + + byte[] serializedData = baos.toByteArray(); + assertThat(serializedData.length) + .as("Serialized data should not be empty") + .isGreaterThan(0); + + // Deserialize + ByteArrayInputStream bais = new ByteArrayInputStream(serializedData); + DataInputStream dis = new DataInputStream(bais); + + CompressionDictionary deserializedDict = CompressionDictionary.deserialize(dis, null); + + assertThat(deserializedDict) + .as("Deserialized dictionary should not be null") + .isNotNull(); + + assertThat(deserializedDict.identifier()) + .as("Deserialized dictionary ID should match") + .isEqualTo(dictionary.identifier()); + + assertThat(deserializedDict.kind()) + .as("Deserialized dictionary kind should match") + .isEqualTo(dictionary.kind()); + + assertThat(deserializedDict.rawDictionary()) + .as("Deserialized dictionary data should match") + .isEqualTo(dictionary.rawDictionary()); + } + + @Test + public void testSerializeDeserializeWithManager() throws Exception + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + dictionary.serialize(dos); + dos.flush(); + + byte[] serializedData = baos.toByteArray(); + + // First deserialization should create and cache the dictionary + ByteArrayInputStream bais1 = new ByteArrayInputStream(serializedData); + DataInputStream dis1 = new DataInputStream(bais1); + CompressionDictionary dict1 = CompressionDictionary.deserialize(dis1, null); + + // Second deserialization should return cached instance + ByteArrayInputStream bais2 = new ByteArrayInputStream(serializedData); + DataInputStream dis2 = new DataInputStream(bais2); + CompressionDictionary dict2 = CompressionDictionary.deserialize(dis2, null); + + assertThat(dict1) + .as("Both deserializations should return identical dictionary") + .isNotNull() + .isEqualTo(dict2); + + dict1.close(); + dict2.close(); + } + + @Test + public void testDeserializeCorruptedData() throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos); + + // Write corrupted data (wrong checksum) + dos.writeByte(Kind.ZSTD.ordinal()); + dos.writeLong(SAMPLE_DICT_ID.id); + dos.writeInt(SAMPLE_DICT_DATA.length); + 
dos.write(SAMPLE_DICT_DATA); + dos.writeInt(0xDEADBEEF); // Wrong checksum + dos.flush(); + + byte[] corruptedData = baos.toByteArray(); + ByteArrayInputStream bais = new ByteArrayInputStream(corruptedData); + DataInputStream dis = new DataInputStream(bais); + + assertThatThrownBy(() -> CompressionDictionary.deserialize(dis, null)) + .as("Should throw exception for corrupted data") + .isInstanceOf(IOException.class) + .hasMessageContaining("checksum does not match"); + } + + private static byte[] createSampleDictionaryData() + { + // Create sample dictionary data that could be used for compression + String sampleText = "The quick brown fox jumps over the lazy dog. "; + return sampleText.repeat(100).getBytes(); + } +} diff --git a/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java b/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java new file mode 100644 index 000000000000..5238232b4cd8 --- /dev/null +++ b/test/unit/org/apache/cassandra/db/compression/ZstdDictionaryTrainerTest.java @@ -0,0 +1,640 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.db.compression; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ICompressionDictionaryTrainer.TrainingStatus; +import org.apache.cassandra.schema.CompressionParams; + +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class ZstdDictionaryTrainerTest +{ + private static final String TEST_KEYSPACE = "test_ks"; + private static final String TEST_TABLE = "test_table"; + private static final String SAMPLE_DATA = "The quick brown fox jumps over the lazy dog. 
"; + private static final int COMPRESSION_LEVEL = 3; + + private CompressionDictionaryTrainingConfig testConfig; + private ZstdDictionaryTrainer trainer; + private Consumer mockCallback; + private AtomicReference callbackResult; + + @BeforeClass + public static void setUpClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Before + public void setUp() + { + testConfig = CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(1024) // Small for testing + .maxTotalSampleSize(10 * 1024) // 10KB total + .samplingRate(1) // 100% sampling for predictable tests + .build(); + + callbackResult = new AtomicReference<>(); + mockCallback = callbackResult::set; + + trainer = new ZstdDictionaryTrainer(TEST_KEYSPACE, TEST_TABLE, testConfig, COMPRESSION_LEVEL); + trainer.setDictionaryTrainedListener(mockCallback); + } + + @After + public void tearDown() throws Exception + { + if (trainer != null) + { + trainer.close(); + } + + // Clean up any dictionary created in callback + CompressionDictionary dict = callbackResult.get(); + if (dict != null) + { + dict.close(); + callbackResult.set(null); + } + } + + @Test + public void testTrainerInitialState() + { + assertThat(trainer.getTrainingStatus()) + .as("Initial status should be NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready initially") + .isFalse(); + assertThat(trainer.kind()) + .as("Should return ZSTD kind") + .isEqualTo(Kind.ZSTD); + } + + @Test + public void testTrainerStart() + { + // Auto start depends on configuration - test both scenarios + boolean started = trainer.start(false); + if (started) + { + assertThat(trainer.getTrainingStatus()) + .as("Status should be SAMPLING if auto-start enabled") + .isEqualTo(TrainingStatus.SAMPLING); + } + else + { + assertThat(trainer.getTrainingStatus()) + .as("Status should remain NOT_STARTED if auto-start disabled") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + } + + @Test + public void testTrainerStartManual() + { + assertThat(trainer.start(true)) + .as("Manual training should start successfully") + .isTrue(); + assertThat(trainer.getTrainingStatus()) + .as("Status should be SAMPLING after start") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready immediately after start") + .isFalse(); + } + + @Test + public void testTrainerStartMultipleTimes() + { + assertThat(trainer.start(true)) + .as("First start (manual training) should succeed") + .isTrue(); + Object firstTrainer = trainer.trainer(); + assertThat(firstTrainer).isNotNull(); + assertThat(trainer.start(true)) + .as("Second start (manual training) should suceed and reset") + .isTrue(); + Object secondTrainer = trainer.trainer(); + assertThat(secondTrainer).isNotNull().isNotSameAs(firstTrainer); + assertThat(trainer.start(false)) + .as("Third start (not manual training) should fail") + .isFalse(); + } + + @Test + public void testTrainerCloseIdempotent() + { + trainer.start(true); + trainer.close(); + trainer.close(); // Should not throw + trainer.close(); // Should not throw + + assertThat(trainer.getTrainingStatus()) + .as("Status should remain NOT_STARTED after multiple closes") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testTrainerReset() + { + trainer.start(true); + addSampleData(1000); // Add some samples + + assertThat(trainer.getSampleCount()) + .as("Should have samples before reset") + .isGreaterThan(0); + + trainer.reset(); + assertThat(trainer.getTrainingStatus()) + .as("Status should 
be NOT_STARTED after reset") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.getSampleCount()) + .as("Sample count should be 0 after reset") + .isEqualTo(0); + assertThat(trainer.isReady()) + .as("Should not be ready after reset") + .isFalse(); + } + + @Test + public void testStartAfterClose() + { + trainer.start(true); + trainer.close(); + + assertThat(trainer.start(true)) + .as("Should not start after close") + .isFalse(); + assertThat(trainer.getTrainingStatus()) + .as("Status should remain NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + } + + @Test + public void testShouldSample() + { + trainer.start(true); + // With sampling rate 1 (100%), should always return true + for (int i = 0; i < 10; i++) + { + assertThat(trainer.shouldSample()) + .as("Should sample with rate 1") + .isTrue(); + } + } + + @Test + public void testShouldSampleWithLowRate() + { + // Test with lower sampling rate + CompressionDictionaryTrainingConfig lowSamplingConfig = + CompressionDictionaryTrainingConfig.builder() + .maxDictionarySize(1024) + .maxTotalSampleSize(10 * 1024) + .samplingRate(500) // 0.2% sampling + .build(); + + try (ZstdDictionaryTrainer lowSamplingTrainer = new ZstdDictionaryTrainer(TEST_KEYSPACE, TEST_TABLE, + lowSamplingConfig, COMPRESSION_LEVEL)) + { + lowSamplingTrainer.setDictionaryTrainedListener(mockCallback); + // With very low sampling rate, should mostly return false + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (lowSamplingTrainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 0.2% (1 out of 500), allow some variance + assertThat(sampleCount) + .as("Sample rate should be low") + .isLessThan(iterations / 10); + } + } + + @Test + public void testAddSample() + { + trainer.start(true); + + assertThat(trainer.getSampleCount()) + .as("Initial sample count should be 0") + .isEqualTo(0); + + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getSampleCount()) + .as("Sample count should be 1 after adding one sample") + .isEqualTo(1); + assertThat(trainer.getTrainingStatus()) + .as("Status should be SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with single small sample") + .isFalse(); + } + + @Test + public void testAddSampleBeforeStart() + { + // Should not accept samples before start + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getTrainingStatus()) + .as("Status should remain NOT_STARTED") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready") + .isFalse(); + } + + @Test + public void testAddSampleAfterClose() + { + trainer.start(true); + trainer.close(); + + ByteBuffer sample = ByteBuffer.wrap(SAMPLE_DATA.getBytes()); + trainer.addSample(sample); + + assertThat(trainer.getTrainingStatus()) + .as("Status should remain NOT_STARTED after close") + .isEqualTo(TrainingStatus.NOT_STARTED); + assertThat(trainer.isReady()) + .as("Should not be ready after close") + .isFalse(); + } + + @Test + public void testAddNullSample() + { + trainer.start(true); + trainer.addSample(null); // Should not throw + + assertThat(trainer.getTrainingStatus()) + .as("Status should remain SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with null sample") + .isFalse(); + } + + @Test + public void testAddEmptySample() + { + 
trainer.start(true); + ByteBuffer empty = ByteBuffer.allocate(0); + trainer.addSample(empty); // Should not throw + + assertThat(trainer.getTrainingStatus()) + .as("Status should remain SAMPLING") + .isEqualTo(TrainingStatus.SAMPLING); + assertThat(trainer.isReady()) + .as("Should not be ready with empty sample") + .isFalse(); + } + + @Test + public void testIsReady() + { + trainer.start(true); + assertThat(trainer.isReady()) + .as("Should not be ready initially") + .isFalse(); + + addSampleData(testConfig.acceptableTotalSampleSize / 2); + assertThat(trainer.isReady()) + .as("Should not be ready with insufficient samples") + .isFalse(); + + addSampleData(testConfig.acceptableTotalSampleSize); + assertThat(trainer.isReady()) + .as("Should be ready after enough samples") + .isTrue(); + + trainer.close(); + + assertThat(trainer.isReady()) + .as("Should not be ready when closed") + .isFalse(); + } + + @Test + public void testTrainDictionaryWithInsufficientSampleCount() + { + trainer.start(true); + + // Add sufficient data size but only 5 samples (less than minimum 10) + for (int i = 0; i < 5; i++) + { + ByteBuffer largeSample = ByteBuffer.wrap(new byte[testConfig.acceptableTotalSampleSize / 5]); + trainer.addSample(largeSample); + } + + assertThat(trainer.getSampleCount()) + .as("Should have 5 samples") + .isEqualTo(5); + assertThat(trainer.isReady()) + .as("Should not be ready with insufficient sample count") + .isFalse(); + + // Trying to train without force should fail + assertThatThrownBy(() -> trainer.trainDictionary(false)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Trainer is not ready"); + + // Force training should fail with insufficient samples + assertThatThrownBy(() -> trainer.trainDictionary(true)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Insufficient samples for training: 5 (minimum required: 10)"); + } + + @Test + public void testTrainDictionaryWithSufficientSampleCount() + { + trainer.start(true); + + // Add 15 samples with sufficient total size + for (int i = 0; i < 15; i++) + { + ByteBuffer sample = ByteBuffer.wrap(new byte[testConfig.acceptableTotalSampleSize / 15 + 1]); + trainer.addSample(sample); + } + + assertThat(trainer.getSampleCount()).isEqualTo(15); + assertThat(trainer.isReady()).isTrue(); + + // Training should succeed + CompressionDictionary dictionary = trainer.trainDictionary(false); + assertThat(dictionary).as("Dictionary should be created").isNotNull(); + assertThat(trainer.getTrainingStatus()).isEqualTo(TrainingStatus.COMPLETED); + } + + @Test + public void testTrainDictionaryAsync() throws Exception + { + CompletableFuture future = startTraining(true, false, testConfig.acceptableTotalSampleSize); + CompressionDictionary dictionary = future.get(5, TimeUnit.SECONDS); + + assertThat(dictionary).as("Dictionary should not be null").isNotNull(); + assertThat(trainer.getTrainingStatus()).as("Status should be COMPLETED").isEqualTo(TrainingStatus.COMPLETED); + + // Verify callback was called + assertThat(callbackResult.get()).as("Callback should have been called").isNotNull(); + assertThat(callbackResult.get().identifier()).as("Callback should receive same dictionary").isEqualTo(dictionary.identifier()); + } + + @Test + public void testTrainDictionaryAsyncForce() throws Exception + { + // Don't add enough samples + CompletableFuture future = startTraining(true, true, 512); + CompressionDictionary dictionary = future.get(1, TimeUnit.SECONDS); + assertThat(dictionary) + .as("Forced async training should produce 
dictionary") + .isNotNull(); + } + + @Test + public void testTrainDictionaryAsyncForceFailsWithNoData() throws Exception + { + AtomicReference dictRef = new AtomicReference<>(); + CompletableFuture result = startTraining(true, true, 0) + .whenComplete((dict, t) -> dictRef.set(dict)); + + assertThat(result.isCompletedExceptionally()) + .as("Result should be completed exceptionally") + .isTrue(); + assertThat(trainer.getTrainingStatus()) + .as("Status should be FAILED") + .isEqualTo(TrainingStatus.FAILED); + assertThat(dictRef.get()) + .as("Dictionary reference should be null") + .isNull(); + } + + @Test + public void testDictionaryTrainedListener() + { + trainer.start(true); + addSampleData(testConfig.acceptableTotalSampleSize); + + // Train dictionary synchronously - callback should be called + CompressionDictionary dictionary = trainer.trainDictionary(false); + + // Verify callback was invoked with the dictionary + assertThat(callbackResult.get()).as("Callback should have been called").isNotNull(); + assertThat(callbackResult.get().identifier().id) + .as("Callback should receive correct dictionary ID") + .isEqualTo(dictionary.identifier().id); + assertThat(callbackResult.get().kind()) + .as("Callback should receive correct dictionary kind") + .isEqualTo(dictionary.kind()); + } + + @Test + public void testMonotonicDictionaryIds() + { + long now = System.currentTimeMillis(); + long id1 = CompressionDictionary.DictId.makeDictId(now, 100L); + long hourLater= now + TimeUnit.HOURS.toMillis(1); + long id2 = CompressionDictionary.DictId.makeDictId(hourLater, 200L); + long id3 = CompressionDictionary.DictId.makeDictId(now, 200L); + + assertThat(id2) + .as("Dictionary IDs should be monotonic over time") + .isGreaterThan(id1) + .isGreaterThan(id3); + + assertThat(id3).isNotEqualTo(id1).isNotEqualTo(id2); + } + + @Test + public void testIsCompatibleWith() + { + CompressionParams compatibleParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "3")); + + assertThat(trainer.isCompatibleWith(compatibleParams)) + .as("Should be compatible with same compression level") + .isTrue(); + + + CompressionParams incompatibleParams = CompressionParams.lz4(); + + assertThat(trainer.isCompatibleWith(incompatibleParams)) + .as("Should not be compatible with different compressor") + .isFalse(); + + CompressionParams differentLevelParams = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true, + Map.of("compression_level", "4")); + + assertThat(trainer.isCompatibleWith(differentLevelParams)) + .as("Should not be compatible with different compression level") + .isFalse(); + + CompressionParams disabledParams = CompressionParams.noCompression(); + + assertThat(trainer.isCompatibleWith(disabledParams)) + .as("Should not be compatible with disabled compression") + .isFalse(); + } + + @Test + public void testUpdateSamplingRate() + { + trainer.start(true); + + // Test updating to different valid sampling rates + trainer.updateSamplingRate(10); + + // With sampling rate 10 (10%), should mostly return false + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (trainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 10% (1 out of 10), allow some variance + assertThat(sampleCount) + .as("Sample rate should be approximately 10%") + .isGreaterThan(iterations / 20) // at least 5% + .isLessThan(iterations / 5); // at most 20% + + // Test updating to 100% sampling + trainer.updateSamplingRate(1); + + // 
Should always sample now + for (int i = 0; i < 10; i++) + { + assertThat(trainer.shouldSample()) + .as("Should always sample with rate 1") + .isTrue(); + } + } + + @Test + public void testUpdateSamplingRateValidation() + { + trainer.start(true); + + // Test invalid sampling rates + assertThatThrownBy(() -> trainer.updateSamplingRate(0)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + + assertThatThrownBy(() -> trainer.updateSamplingRate(-1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + + assertThatThrownBy(() -> trainer.updateSamplingRate(-100)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Sampling rate must be positive"); + } + + @Test + public void testUpdateSamplingRateBeforeStart() + { + // Should be able to update sampling rate even before start + trainer.updateSamplingRate(5); + + trainer.start(true); + + // Verify the updated rate is used after start + int sampleCount = 0; + int iterations = 1000; + for (int i = 0; i < iterations; i++) + { + if (trainer.shouldSample()) + { + sampleCount++; + } + } + + // Should be roughly 20% (1 out of 5), allow some variance + assertThat(sampleCount) + .as("Sample rate should be approximately 20%") + .isGreaterThan(iterations / 10) // at least 10% + .isLessThan(iterations / 2); // at most 50% + } + + private CompletableFuture startTraining(boolean manualTraining, boolean forceTrain, int sampleSize) throws Exception + { + trainer.start(manualTraining); + if (sampleSize > 0) + { + addSampleData(sampleSize); + } + + if (forceTrain) + { + assertThat(trainer.isReady()) + .as("Trainer should not be ready to train due to lack of samples") + .isFalse(); + } + + CountDownLatch latch = new CountDownLatch(1); + CompletableFuture future = trainer.trainDictionaryAsync(forceTrain) + .whenComplete((dict, throwable) -> latch.countDown()); + assertThat(latch.await(10, TimeUnit.SECONDS)) + .as("Training should complete within timeout") + .isTrue(); + return future; + } + + private void addSampleData(int totalSize) + { + byte[] sampleBytes = SAMPLE_DATA.getBytes(); + int samplesNeeded = (totalSize + sampleBytes.length - 1) / sampleBytes.length; // Round up + + for (int i = 0; i < samplesNeeded; i++) + { + ByteBuffer sample = ByteBuffer.wrap(sampleBytes); + trainer.addSample(sample); + } + } +} diff --git a/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java b/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java index fbc17c272992..f2a56c33640c 100644 --- a/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CQLCompressionTest.java @@ -81,11 +81,11 @@ public void zstdParamsTest() { createTable("create table %s (id int primary key, uh text) with compression = {'class':'ZstdCompressor', 'compression_level':-22}"); assertTrue(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getClass().equals(ZstdCompressor.class)); - assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getCompressionLevel(), -22); + assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).compressionLevel(), -22); createTable("create table %s (id int primary key, uh text) with compression = {'class':'ZstdCompressor'}"); 
assertTrue(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getClass().equals(ZstdCompressor.class)); - assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).getCompressionLevel(), ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + assertEquals(((ZstdCompressor)getCurrentColumnFamilyStore().metadata().params.compression.getSstableCompressor()).compressionLevel(), ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); } @Test(expected = ConfigurationException.class) diff --git a/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java b/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java index 2bf127f0790d..056761b5602c 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressedRandomAccessReaderTest.java @@ -314,4 +314,4 @@ private static void updateChecksum(RandomAccessFile file, long checksumOffset, b file.write(checksum); SyncUtil.sync(file.getFD()); } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java index afa469c48772..ada098e6d057 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressedSequentialWriterTest.java @@ -395,4 +395,4 @@ void cleanup() } } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java b/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java index 321fe5735606..560a0af63520 100644 --- a/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java +++ b/test/unit/org/apache/cassandra/io/compress/CompressionMetadataTest.java @@ -42,7 +42,8 @@ private CompressionMetadata newCompressionMetadata(Memory memory) memory, memory.size(), dataLength, - compressedFileLength); + compressedFileLength, + null); } @Test @@ -75,4 +76,4 @@ public void testMemoryIsShared() assertThat(copy.isCleanedUp()).isTrue(); assertThatExceptionOfType(AssertionError.class).isThrownBy(memory::size); } -} \ No newline at end of file +} diff --git a/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java b/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java index 70e32ad4d252..b360280b9e32 100644 --- a/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java +++ b/test/unit/org/apache/cassandra/io/compress/ZstdCompressorTest.java @@ -36,7 +36,7 @@ public class ZstdCompressorTest public void emptyConfigurationUsesDefaultCompressionLevel() { ZstdCompressor compressor = ZstdCompressor.create(Collections.emptyMap()); - assertEquals(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL, compressor.getCompressionLevel()); + assertEquals(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL, compressor.compressionLevel()); } @Test(expected = IllegalArgumentException.class) diff --git a/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java b/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java new file mode 100644 index 000000000000..ef5e3d11d8d7 --- /dev/null +++ b/test/unit/org/apache/cassandra/io/compress/ZstdDictionaryCompressorTest.java @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.compress; + +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdDictTrainer; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.Random; + +import static org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import static org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.Assert.fail; + +public class ZstdDictionaryCompressorTest +{ + private static final int TEST_DATA_SIZE = 1024; + private static final String REPEATED_PATTERN = "The quick brown fox jumps over the lazy dog. "; + + private static byte[] testData; + private static byte[] compressibleData; + private static ZstdCompressionDictionary testDictionary; + + @BeforeClass + public static void setup() + { + DatabaseDescriptor.daemonInitialization(); + testData = new byte[TEST_DATA_SIZE]; + new Random(42).nextBytes(testData); + + // Generate compressible data + StringBuilder sb = new StringBuilder(); + while (sb.length() < TEST_DATA_SIZE) + { + sb.append(REPEATED_PATTERN); + } + compressibleData = sb.substring(0, TEST_DATA_SIZE).getBytes(); + testDictionary = createTestDictionary(); + } + + @AfterClass + public static void tearDown() + { + if (testDictionary != null) + { + testDictionary.close(); + } + } + + @Test + public void testCreateWithOptions() + { + Map options = Map.of(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME, "5"); + + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(options); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(5); + assertThat(compressor.dictionary()).isNull(); // No dictionary should be set + } + + @Test + public void testCreateWithEmptyOptions() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + } + + @Test + public void testCreateWithDictionary() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + assertThat(compressor).isNotNull(); + assertThat(compressor.compressionLevel()).isEqualTo(ZstdCompressor.DEFAULT_COMPRESSION_LEVEL); + assertThat(compressor.dictionary()).isSameAs(testDictionary); + } + + @Test + public void testCreateWithInvalidCompressionLevel() + { + String invalidLevel = 
String.valueOf(Zstd.maxCompressionLevel() + 1); + Map options = Map.of(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME, invalidLevel); + + assertThatThrownBy(() -> ZstdDictionaryCompressor.create(options)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(ZstdCompressor.COMPRESSION_LEVEL_OPTION_NAME + '=' + invalidLevel + " is invalid"); + } + + @Test + public void testCompressDecompressWithDictionary() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + ByteBuffer input = ByteBuffer.allocateDirect(compressibleData.length); + input.put(compressibleData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(compressibleData.length)); + + // Compress + compressor.compress(input, compressed); + compressed.flip(); + + assertThat(compressed.remaining()) + .as("Data should be compressed") + .isLessThan(compressibleData.length); + + // Decompress + ByteBuffer decompressed = ByteBuffer.allocateDirect(compressibleData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + // Verify roundtrip + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result).isEqualTo(compressibleData); + } + + @Test + public void testCompressDecompressWithoutDictionary() throws IOException + { + // Test fallback behavior when no dictionary is provided + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + + ByteBuffer input = ByteBuffer.allocateDirect(testData.length); + input.put(testData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(testData.length)); + + // Compress + compressor.compress(input, compressed); + compressed.flip(); + + // Decompress + ByteBuffer decompressed = ByteBuffer.allocateDirect(testData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + // Verify roundtrip + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result).isEqualTo(testData); + } + + @Test + public void testCompressDecompressByteArray() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + // Test byte array compression/decompression using direct buffers + ByteBuffer input = ByteBuffer.allocateDirect(compressibleData.length); + input.put(compressibleData); + input.flip(); + + ByteBuffer output = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(compressibleData.length)); + + compressor.compress(input, output); + int compressedLength = output.position(); + + // Extract compressed data to byte array for array-based decompression test + byte[] compressed = new byte[compressedLength]; + output.flip(); + output.get(compressed); + + // Decompress using byte array method + byte[] decompressed = new byte[compressibleData.length]; + int decompressedLength = compressor.uncompress(compressed, 0, compressedLength, decompressed, 0); + + assertThat(decompressedLength).isEqualTo(compressibleData.length); + assertThat(decompressed).isEqualTo(compressibleData); + } + + @Test + public void testDictionaryCompressionImprovement() + { + // Test that dictionary compression provides better compression ratio + ZstdDictionaryCompressor dictCompressor = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor noDictCompressor = 
ZstdDictionaryCompressor.create(Collections.emptyMap()); + + ByteBuffer input1 = ByteBuffer.allocateDirect(compressibleData.length); + input1.put(compressibleData); + input1.flip(); + + ByteBuffer input2 = ByteBuffer.allocateDirect(compressibleData.length); + input2.put(compressibleData); + input2.flip(); + + ByteBuffer dictCompressed = ByteBuffer.allocateDirect(dictCompressor.initialCompressedBufferLength(compressibleData.length)); + ByteBuffer noDictCompressed = ByteBuffer.allocateDirect(noDictCompressor.initialCompressedBufferLength(compressibleData.length)); + + try + { + dictCompressor.compress(input1, dictCompressed); + noDictCompressor.compress(input2, noDictCompressed); + + dictCompressed.flip(); + noDictCompressed.flip(); + + // Dictionary compression should achieve better compression ratio for repetitive data + assertThat(dictCompressed.remaining()) + .as("Dictionary compression should achieve better compression ratio") + .isLessThanOrEqualTo(noDictCompressed.remaining()); + } + catch (IOException e) + { + fail("Compression should not fail: " + e.getMessage()); + } + } + + @Test + public void testCompressorCaching() + { + // Test that same dictionary returns same compressor instance + ZstdDictionaryCompressor compressor1 = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor compressor2 = ZstdDictionaryCompressor.create(testDictionary); + + assertThat(compressor1) + .as("Same dictionary should return cached compressor instance") + .isSameAs(compressor2); + } + + @Test + public void testGetOrCopyWithDictionary() + { + ZstdDictionaryCompressor originalCompressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + ZstdDictionaryCompressor dictCompressor = originalCompressor.getOrCopyWithDictionary(testDictionary); + + assertThat(dictCompressor) + .as("Should return different compressor instance") + .isNotSameAs(originalCompressor); + assertThat(dictCompressor.dictionary()) + .as("Should have the provided dictionary") + .isSameAs(testDictionary); + assertThat(dictCompressor.compressionLevel()) + .as("Should preserve compression level") + .isEqualTo(originalCompressor.compressionLevel()); + } + + @Test + public void testGetOrCopyWithSameDictionary() + { + ZstdDictionaryCompressor originalCompressor = ZstdDictionaryCompressor.create(testDictionary); + ZstdDictionaryCompressor sameCompressor = originalCompressor.getOrCopyWithDictionary(testDictionary); + + assertThat(sameCompressor) + .as("Same dictionary should return same compressor") + .isSameAs(originalCompressor); + } + + @Test + public void testClosedDictionaryHandling() + { + ZstdDictionaryCompressor.invalidateCache(); + ZstdCompressionDictionary closedDict = createTestDictionary(); + closedDict.close(); + + // This should throw IllegalStateException + assertThatThrownBy(() -> ZstdDictionaryCompressor.create(closedDict)) + .isInstanceOf(IllegalStateException.class); + } + + @Test + public void testCompressionWithNullDictionary() throws IOException + { + // Test that null dictionary falls back to standard compression + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create((ZstdCompressionDictionary) null); + + ByteBuffer input = ByteBuffer.allocateDirect(testData.length); + input.put(testData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(testData.length)); + + // Should not throw exception, should fall back to standard Zstd + compressor.compress(input, compressed); + compressed.flip(); + + ByteBuffer decompressed = 
ByteBuffer.allocateDirect(testData.length); + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + decompressed.flip(); + + byte[] result = new byte[decompressed.remaining()]; + decompressed.get(result); + assertThat(result) + .as("Null dictionary should fall back to standard compression") + .isEqualTo(testData); + } + + @Test + public void testDecompressionFailureHandling() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + // Create invalid compressed data + byte[] invalidData = new byte[10]; + new Random().nextBytes(invalidData); + + byte[] output = new byte[100]; + + assertThatThrownBy(() -> compressor.uncompress(invalidData, 0, invalidData.length, output, 0)) + .isInstanceOf(IOException.class) + .hasMessageContaining("Decompression failed"); + } + + @Test + public void testAcceptableDictionaryKind() + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(Collections.emptyMap()); + assertThat(compressor.acceptableDictionaryKind()) + .as("Should accept ZSTD dictionary kind") + .isEqualTo(Kind.ZSTD); + } + + @Test + public void testEmptyDataCompression() throws IOException + { + ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(testDictionary); + + byte[] emptyData = new byte[0]; + ByteBuffer input = ByteBuffer.allocateDirect(emptyData.length + 1); // Allocate at least 1 byte for direct buffer + input.put(emptyData); + input.flip(); + + ByteBuffer compressed = ByteBuffer.allocateDirect(Math.max(1, compressor.initialCompressedBufferLength(0))); + + compressor.compress(input, compressed); + compressed.flip(); + + ByteBuffer decompressed = ByteBuffer.allocateDirect(1); // Allocate at least 1 byte for direct buffer + compressed.rewind(); + compressor.uncompress(compressed, decompressed); + + assertThat(decompressed.position()) + .as("Should have written nothing for empty data") + .isEqualTo(0); + } + + private static ZstdCompressionDictionary createTestDictionary() + { + try + { + int sampleSize = 100 * 1024; + int dictSize = 6 * 1024; + // Create a simple dictionary from repetitive data + ZstdDictTrainer trainer = new ZstdDictTrainer(sampleSize, dictSize, 3); + + for (int i = 0; i < 1000; i++) + { + trainer.addSample(compressibleData); + } + + byte[] dictBytes = trainer.trainSamples(); + DictId dictId = new DictId(Kind.ZSTD, 1); + + return new ZstdCompressionDictionary(dictId, dictBytes); + } + catch (Exception e) + { + throw new RuntimeException("Failed to create test dictionary", e); + } + } +} diff --git a/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java b/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java index bd8f846572ff..7696d5d9d59e 100644 --- a/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java +++ b/test/unit/org/apache/cassandra/io/sstable/ScrubTest.java @@ -559,7 +559,7 @@ public static void overrideWithGarbage(SSTableReader sstable, ByteBuffer key1, B if (compression) { // overwrite with garbage the compression chunks from key1 to key2 - CompressionMetadata compData = CompressionInfoComponent.load(sstable.descriptor); + CompressionMetadata compData = CompressionInfoComponent.load(sstable.descriptor, null); CompressionMetadata.Chunk chunk1 = compData.chunkFor( sstable.getPosition(PartitionPosition.ForKey.get(key1, sstable.getPartitioner()), SSTableReader.Operator.EQ)); diff --git a/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java b/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java new file mode 100644 index 
000000000000..70f1fbbc1938 --- /dev/null +++ b/test/unit/org/apache/cassandra/schema/CompressionParamsTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.schema; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.config.DatabaseDescriptor; + +import static org.assertj.core.api.Assertions.assertThat; + +public class CompressionParamsTest +{ + @BeforeClass + public static void beforeClass() + { + DatabaseDescriptor.daemonInitialization(); + } + + @Test + public void testIsDictionaryCompressionEnabled() + { + CompressionParams noCompression = CompressionParams.noCompression(); + assertThat(noCompression.isDictionaryCompressionEnabled()) + .as("No compression should not enable dictionary compression") + .isFalse(); + + CompressionParams regularZstd = CompressionParams.zstd(); + assertThat(regularZstd.isDictionaryCompressionEnabled()) + .as("Regular Zstd compression should not enable dictionary compression") + .isFalse(); + + CompressionParams zstdDictionary = CompressionParams.zstd(CompressionParams.DEFAULT_CHUNK_LENGTH, true); + assertThat(zstdDictionary.isDictionaryCompressionEnabled()) + .as("Zstd dictionary compression should enable dictionary compression") + .isTrue(); + + CompressionParams lz4 = CompressionParams.lz4(); + assertThat(lz4.isDictionaryCompressionEnabled()) + .as("LZ4 compression should not enable dictionary compression") + .isFalse(); + + CompressionParams snappy = CompressionParams.snappy(); + assertThat(snappy.isDictionaryCompressionEnabled()) + .as("Snappy compression should not enable dictionary compression") + .isFalse(); + + CompressionParams deflate = CompressionParams.deflate(); + assertThat(deflate.isDictionaryCompressionEnabled()) + .as("Deflate compression should not enable dictionary compression") + .isFalse(); + + CompressionParams noop = CompressionParams.noop(); + assertThat(noop.isDictionaryCompressionEnabled()) + .as("Noop compression should not enable dictionary compression") + .isFalse(); + } +} diff --git a/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java new file mode 100644 index 000000000000..ca0c50e79047 --- /dev/null +++ b/test/unit/org/apache/cassandra/schema/SystemDistributedKeyspaceCompressionDictionaryTest.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.schema; + +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.cql3.QueryProcessor; +import org.apache.cassandra.db.compression.CompressionDictionary; +import org.apache.cassandra.db.compression.CompressionDictionary.DictId; +import org.apache.cassandra.db.compression.CompressionDictionary.Kind; +import org.apache.cassandra.db.compression.ZstdCompressionDictionary; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SystemDistributedKeyspaceCompressionDictionaryTest extends CQLTester +{ + private static final String TEST_KEYSPACE = "test_keyspace"; + private static final String TEST_TABLE = "test_table"; + private static final String OTHER_TABLE = "other_table"; + + private CompressionDictionary testDictionary1; + private CompressionDictionary testDictionary2; + + @Before + public void setUp() + { + DictId dictId1 = new DictId(Kind.ZSTD, 100L); + DictId dictId2 = new DictId(Kind.ZSTD, 200L); + + byte[] dictData1 = "test dictionary data 1".getBytes(); + byte[] dictData2 = "test dictionary data 2".getBytes(); + + testDictionary1 = new ZstdCompressionDictionary(dictId1, dictData1); + testDictionary2 = new ZstdCompressionDictionary(dictId2, dictData2); + + clearCompressionDictionaries(); + } + + @Test + public void testCompressionDictionariesTableExists() + { + Set tableNames = SystemDistributedKeyspace.TABLE_NAMES; + + assertThat(tableNames) + .as("TABLE_NAMES should contain compression_dictionaries") + .contains(SystemDistributedKeyspace.COMPRESSION_DICTIONARIES); + + // Verify the table exists in the schema + KeyspaceMetadata systemDistributedKs = SystemDistributedKeyspace.metadata(); + TableMetadata compressionDictTable = systemDistributedKs + .getTableOrViewNullable(SystemDistributedKeyspace.COMPRESSION_DICTIONARIES); + + assertThat(compressionDictTable) + .as("compression_dictionaries table should exist in schema") + .isNotNull(); + } + + @Test + public void testStoreCompressionDictionary() throws Exception + { + // Store a dictionary + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + + // Verify it was stored + CompressionDictionary retrieved = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE); + + assertThat(retrieved) + .as("Retrieved dictionary should not be null") + .isNotNull(); + + assertThat(retrieved.identifier()) + .as("Retrieved dictionary ID should match stored") + .isEqualTo(testDictionary1.identifier()); + + assertThat(retrieved.kind()) + .as("Retrieved dictionary kind should match stored") + .isEqualTo(testDictionary1.kind()); + + assertThat(retrieved.rawDictionary()) + .as("Retrieved dictionary data should match stored") + .isEqualTo(testDictionary1.rawDictionary()); + + retrieved.close(); + } + + @Test + public void 
testStoreMultipleDictionaries() throws Exception + { + // Store multiple dictionaries for the same table + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary2); + + // Should retrieve the latest one (higher ID due to clustering order) + CompressionDictionary latest = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE); + + assertThat(latest) + .as("Should retrieve the latest dictionary") + .isNotNull(); + + assertThat(latest.identifier()) + .as("Should retrieve dictionary with higher ID") + .isEqualTo(testDictionary2.identifier()); + + latest.close(); + } + + @Test + public void testRetrieveSpecificDictionary() throws Exception + { + // Store both dictionaries + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary1); + SystemDistributedKeyspace.storeCompressionDictionary(TEST_KEYSPACE, TEST_TABLE, testDictionary2); + + // Retrieve specific dictionary by ID + CompressionDictionary dict1 = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 100L)); + CompressionDictionary dict2 = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 200L)); + + assertThat(dict1) + .as("Should retrieve dictionary 1") + .isNotNull(); + + assertThat(dict1.identifier()) + .as("Should retrieve correct dictionary by ID") + .isEqualTo(testDictionary1.identifier()); + + assertThat(dict2) + .as("Should retrieve dictionary 2") + .isNotNull(); + + assertThat(dict2.identifier()) + .as("Should retrieve correct dictionary by ID") + .isEqualTo(testDictionary2.identifier()); + + dict1.close(); + dict2.close(); + } + + @Test + public void testRetrieveNonExistentDictionary() + { + // Try to retrieve dictionary that doesn't exist + CompressionDictionary nonExistent = SystemDistributedKeyspace.retrieveLatestCompressionDictionary( + "nonexistent_keyspace", "nonexistent_table"); + + assertThat(nonExistent) + .as("Should return null for non-existent dictionary") + .isNull(); + + // Try to retrieve specific dictionary that doesn't exist + CompressionDictionary nonExistentById = SystemDistributedKeyspace.retrieveCompressionDictionary( + TEST_KEYSPACE, TEST_TABLE, new DictId(Kind.ZSTD, 999L)); + + assertThat(nonExistentById) + .as("Should return null for non-existent dictionary ID") + .isNull(); + } + + private void clearCompressionDictionaries() + { + for (String table : List.of(TEST_TABLE, OTHER_TABLE)) + { + String deleteQuery = String.format("DELETE FROM %s.%s WHERE keyspace_name = '%s' AND table_name = '%s'", + SchemaConstants.DISTRIBUTED_KEYSPACE_NAME, + SystemDistributedKeyspace.COMPRESSION_DICTIONARIES, + TEST_KEYSPACE, + table); + QueryProcessor.executeInternal(deleteQuery); + } + } +} diff --git a/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java b/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java index 391d58972106..44e5bc2719e8 100644 --- a/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java +++ b/test/unit/org/apache/cassandra/streaming/compression/CompressedInputStreamTest.java @@ -142,7 +142,7 @@ private void testCompressedReadWith(long[] valuesToCheck, boolean testTruncate, writer.finish(); } - CompressionMetadata comp = CompressionInfoComponent.load(desc); + CompressionMetadata comp = 
CompressionInfoComponent.load(desc, null); List sections = new ArrayList<>(); for (long l : valuesToCheck) { diff --git a/test/unit/org/apache/cassandra/tools/ToolRunner.java b/test/unit/org/apache/cassandra/tools/ToolRunner.java index 3d4d4588fdd7..143b7cfbd693 100644 --- a/test/unit/org/apache/cassandra/tools/ToolRunner.java +++ b/test/unit/org/apache/cassandra/tools/ToolRunner.java @@ -666,6 +666,13 @@ public AssertHelp errorContainsAny(String... messages) return this; } + public AssertHelp stdoutContains(String message) + { + assertThat(message).hasSizeGreaterThan(0); + assertThat(stdout).isNotNull().contains(message); + return this; + } + private void fail(String msg) { StringBuilder sb = new StringBuilder(); diff --git a/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java b/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java new file mode 100644 index 000000000000..9b8ffc097684 --- /dev/null +++ b/test/unit/org/apache/cassandra/tools/nodetool/TrainCompressionDictionaryTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.cassandra.tools.nodetool; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.tools.ToolRunner; + +import static org.apache.cassandra.tools.ToolRunner.invokeNodetool; +import static org.assertj.core.api.Assertions.assertThat; + +public class TrainCompressionDictionaryTest extends CQLTester +{ + @BeforeClass + public static void setup() throws Throwable + { + requireNetwork(); + startJMXServer(); + } + + @Test + public void testTrainCommandSuccess() + { + // Create a table with dictionary compression enabled + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Add some data to make training meaningful + for (int i = 0; i < 100; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "This is sample data for compression dictionary training " + i); + } + flush(keyspace()); + + // Test async training command + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", "--async", keyspace(), table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should indicate training started") + .contains("Training started asynchronously") + .contains(keyspace()) + .contains(table); + } + + @Test + public void testTrainCommandWithCustomDuration() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Add test data + for (int i = 0; i < 50; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "Sample text for dictionary training " + i); + } + flush(keyspace()); + + // Test with custom sampling duration + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + "--max-sampling-duration", "30", + keyspace(), + table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should use custom sampling duration") + .contains("Will collect samples for up to 30 seconds"); + } + + @Test + public void testStatusCommand() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Check status before any training + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--status", + keyspace(), + table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should show initial status") + .containsAnyOf("Trainer is not running", "Trainer is collecting sample data"); + } + + @Test + public void testStatusAfterTrainingStart() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Add data + for (int i = 0; i < 20; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "Data for training " + i); + } + flush(keyspace()); + + // Start training asynchronously + invokeNodetool("traincompressiondictionary", "--async", keyspace(), table) + .assertOnCleanExit(); + + // Check status - should show SAMPLING or TRAINING + ToolRunner.ToolResult statusResult = invokeNodetool("traincompressiondictionary", + "--status", + keyspace(), + table); + statusResult.assertOnCleanExit(); + + assertThat(statusResult.getStdout()) + .as("Should show training in progress") + .containsAnyOf("collecting sample data", "Training is in progress"); + } + + @Test + public void testInvalidKeyspace() + { + ToolRunner.ToolResult 
result = invokeNodetool("traincompressiondictionary", + "--status", + "nonexistent_keyspace", + "nonexistent_table"); + result.asserts() + .failure() + .errorContains("Failed to get training status"); + } + + @Test + public void testInvalidTable() + { + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--status", + keyspace(), + "nonexistent_table"); + result.asserts() + .failure() + .errorContains("Failed to get training status"); + } + + @Test + public void testTrainingOnNonDictionaryTable() + { + // Create table without dictionary compression + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'LZ4Compressor'}"); + + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("does not support dictionary compression"); + } + + @Test + public void testTrainingWithoutDictionaryCompressionEnabled() + { + // Create table with Zstd but without dictionary compression + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdCompressor'}"); + + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("does not support dictionary compression"); + } + + @Test + public void testInvalidSamplingDuration() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Test with invalid (negative) duration + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + "--max-sampling-duration", "-10", + keyspace(), + table); + + // Command line parser should handle this validation + result.asserts() + .failure(); + } + + @Test + public void testHelpOutput() + { + ToolRunner.ToolResult result = invokeNodetool("help", "traincompressiondictionary"); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should show command help") + .contains("nodetool traincompressiondictionary - Manually trigger compression") + .contains("dictionary training for a table") + .contains("keyspace name") + .contains("table name") + .contains("-a, --async") + .contains("-d , --max-sampling-duration") + .contains("-r , --sampling-rate") + .contains("-s, --status"); + } + + @Test + public void testAllStatusValues() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Test NOT_STARTED status + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--status", + keyspace(), + table); + result.assertOnCleanExit(); + + String output = result.getStdout(); + assertThat(output) + .as("Should handle NOT_STARTED status appropriately") + .satisfiesAnyOf(out -> assertThat(out).contains("not running"), + out -> assertThat(out).contains("NOT_STARTED")); + } + + @Test + public void testCommandLineArgumentParsing() + { + // Test missing required arguments + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary"); + result.asserts() + .failure() + .stdoutContains("Missing required parameter"); + + // Test missing table argument + result = invokeNodetool("traincompressiondictionary", keyspace()); + result.asserts() + .failure() + .stdoutContains("Missing required parameter"); + } + + @Test + public void testMutuallyExclusiveOptions() + { 
+ String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Both --async and --status should work independently + invokeNodetool("traincompressiondictionary", "--async", keyspace(), table) + .assertOnCleanExit(); + + invokeNodetool("traincompressiondictionary", "--status", keyspace(), table) + .assertOnCleanExit(); + } + + @Test + public void testStatusOutputFormatting() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--status", + keyspace(), + table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Status output should include keyspace and table names") + .contains(keyspace()) + .contains(table); + } + + @Test + public void testSamplingRateOption() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Add some data + for (int i = 0; i < 20; i++) + { + execute("INSERT INTO %s (id, data) VALUES (?, ?)", i, "Data for sampling rate test " + i); + } + flush(keyspace()); + + // Test with valid sampling rates + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + "--sampling-rate", "0.5", + keyspace(), + table); + result.assertOnCleanExit(); + + assertThat(result.getStdout()) + .as("Should show sampling rate was used") + .contains("Using sampling rate: 0.50 (50.0%)"); + } + + @Test + public void testInvalidSamplingRate() + { + String table = createTable("CREATE TABLE %s (id int PRIMARY KEY, data text) WITH compression = {'class': 'ZstdDictionaryCompressor'}"); + + // Test with sampling rate too high + ToolRunner.ToolResult result = invokeNodetool("traincompressiondictionary", + "--async", + "--sampling-rate", "1.5", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("Invalid value for sampling-rate: 1.5. Must be in range (0, 1]"); + + // Test with sampling rate zero + result = invokeNodetool("traincompressiondictionary", + "--async", + "--sampling-rate", "0.0", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("Invalid value for sampling-rate: 0.0. Must be in range (0, 1]"); + + // Test with negative sampling rate + result = invokeNodetool("traincompressiondictionary", + "--async", + "--sampling-rate", "-0.5", + keyspace(), + table); + result.asserts() + .failure() + .errorContains("Invalid value for sampling-rate: -0.5. Must be in range (0, 1]"); + } +}
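
For illustration, the train-then-compress round trip that ZstdDictionaryCompressorTest exercises above can be reduced to a small standalone sketch. This is only a sketch under assumptions, not part of the patch: it reuses the zstd-jni ZstdDictTrainer call pattern from createTestDictionary() and the org.apache.cassandra.db.compression types imported by the tests, and it assumes ZstdDictionaryCompressor lives in org.apache.cassandra.io.compress next to ZstdCompressor; the example class name and sample payloads are made up for the demo.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import com.github.luben.zstd.ZstdDictTrainer;

import org.apache.cassandra.db.compression.CompressionDictionary.DictId;
import org.apache.cassandra.db.compression.CompressionDictionary.Kind;
import org.apache.cassandra.db.compression.ZstdCompressionDictionary;
import org.apache.cassandra.io.compress.ZstdDictionaryCompressor; // assumed package

public final class DictionaryCompressionExample
{
    public static void main(String[] args) throws IOException
    {
        // Build a repetitive, structured payload; dictionary compression helps most on data like this.
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 64; i++)
            sb.append("user-").append(i % 8).append(",event=page_view,region=us-west;");
        byte[] payload = sb.toString().getBytes(StandardCharsets.UTF_8);

        // Train a small dictionary from many similar samples, mirroring createTestDictionary()
        // (100 KiB sample budget, 6 KiB dictionary, level 3). Samples vary slightly so the
        // trainer sees shared structure rather than identical blobs.
        ZstdDictTrainer trainer = new ZstdDictTrainer(100 * 1024, 6 * 1024, 3);
        for (int i = 0; i < 1000; i++)
            trainer.addSample((sb + "seq=" + i).getBytes(StandardCharsets.UTF_8));
        ZstdCompressionDictionary dictionary =
            new ZstdCompressionDictionary(new DictId(Kind.ZSTD, 1), trainer.trainSamples());

        // create() caches compressors per dictionary, so repeated calls may return the same instance.
        ZstdDictionaryCompressor compressor = ZstdDictionaryCompressor.create(dictionary);

        // The compressor works on direct ByteBuffers, as in the tests above.
        ByteBuffer input = ByteBuffer.allocateDirect(payload.length);
        input.put(payload);
        input.flip();
        ByteBuffer compressed = ByteBuffer.allocateDirect(compressor.initialCompressedBufferLength(payload.length));
        compressor.compress(input, compressed);
        compressed.flip();

        ByteBuffer decompressed = ByteBuffer.allocateDirect(payload.length);
        compressor.uncompress(compressed, decompressed);
        decompressed.flip();

        System.out.printf("original=%d compressed=%d roundtrip-intact=%b%n",
                          payload.length, compressed.limit(), decompressed.remaining() == payload.length);
    }
}

Operationally, the same pieces are wired together by the rest of this patch: a table opts in with compression = {'class': 'ZstdDictionaryCompressor'}, a dictionary is trained via nodetool traincompressiondictionary, and the result is persisted in system_distributed.compression_dictionaries, which is what the TrainCompressionDictionaryTest and SystemDistributedKeyspace tests above exercise.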