Skip to content

Commit 23ec1c8

Browse files
michaeljmarshall authored and maedhroz committed
Refactor SAI ANN query execution to use score ordered iterators for correctness and speed
Rewrites the ANN search query execution logic to merge graph search results across segments using similarity score-ordered (descending) iterators. This allows for reduced memory consumption during queries, reduced impact of overwrites and tombstones, selective re-querying of only the minimally necessary graphs, and reduced shuffling of PrimaryKey objects. Patch by Michael Marshall; reviewed by Caleb Rackliffe and Michael Semb Wever for CASSANDRA-20086.
1 parent a2e5568 commit 23ec1c8

File tree

71 files changed

+3640
-1132
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+3640
-1132
lines changed

CHANGES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
5.0.7
2+
* Refactor SAI ANN query execution to use score ordered iterators for correctness and speed (CASSANDRA-20086)
23
* Disallow binding an identity to a superuser when the user is a regular user (CASSANDRA-21219)
34
* Fix ConcurrentModificationException in compaction garbagecollect (CASSANDRA-21065)
45
* Dynamically skip sharding L0 when SAI Vector index present (CASSANDRA-19661)

src/java/org/apache/cassandra/config/CassandraRelevantProperties.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -460,14 +460,19 @@ public enum CassandraRelevantProperties
460460
/** Whether to allow the user to specify custom options to the hnsw index */
461461
SAI_VECTOR_ALLOW_CUSTOM_PARAMETERS("cassandra.sai.vector.allow_custom_parameters", "false"),
462462

463-
/** Controls the maximum top-k limit for vector search */
464-
SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"),
465-
466463
/**
467-
* Controls the maximum number of PrimaryKeys that will be read into memory at one time when ordering/limiting
468-
* the results of an ANN query constrained by non-ANN predicates.
464+
* The maximum number of primary keys that a WHERE clause may materialize before the query planner switches
465+
* from a search-then-sort execution strategy to an order-by-then-filter strategy. Increasing this limit allows
466+
* more primary keys to be buffered in memory, enabling either (a) brute-force sorting or (b) graph traversal
467+
* with a restrictive filter that admits only nodes whose primary keys matched the WHERE clause.
468+
*
469+
* Note also that the SAI_INTERSECTION_CLAUSE_LIMIT is applied to the WHERE clause before using a search to
470+
* build a potential result set for search-then-sort query execution.
469471
*/
470-
SAI_VECTOR_SEARCH_ORDER_CHUNK_SIZE("cassandra.sai.vector_search.order_chunk_size", "100000"),
472+
SAI_VECTOR_SEARCH_MAX_MATERIALIZE_KEYS("cassandra.sai.vector_search.max_materialized_keys", "16000"),
473+
474+
/** Controls the maximum top-k limit for vector search */
475+
SAI_VECTOR_SEARCH_MAX_TOP_K("cassandra.sai.vector_search.max_top_k", "1000"),
471476

472477
SCHEMA_PULL_INTERVAL_MS("cassandra.schema_pull_interval_ms", "60000"),
473478
SCHEMA_UPDATE_HANDLER_FACTORY_CLASS("cassandra.schema.update_handler_factory.class"),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.cassandra.db;
20+
21+
/**
22+
* A cell's source data object. Can be used to determine if two cells originated from the same object, e.g. memtable
23+
* or sstable.
24+
*/
25+
public interface CellSourceIdentifier
26+
{
27+
/**
28+
* Returns true iff this and other CellSourceIdentifier are equal, indicating that the cells are from the same
29+
* source.
30+
* @param other the other source with which to compare
31+
* @return true if the two sources are equal
32+
*/
33+
default boolean isEqualSource(CellSourceIdentifier other)
34+
{
35+
return this.equals(other);
36+
}
37+
}

src/java/org/apache/cassandra/db/SinglePartitionReadCommand.java

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.NavigableSet;
2727
import java.util.TreeSet;
2828
import java.util.concurrent.TimeUnit;
29+
import java.util.function.Function;
2930
import java.util.stream.Collectors;
3031

3132
import com.google.common.annotations.VisibleForTesting;
@@ -51,6 +52,7 @@
5152
import org.apache.cassandra.db.partitions.PartitionIterators;
5253
import org.apache.cassandra.db.partitions.SingletonUnfilteredPartitionIterator;
5354
import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator;
55+
import org.apache.cassandra.db.rows.BaseRowIterator;
5456
import org.apache.cassandra.db.rows.Cell;
5557
import org.apache.cassandra.db.rows.Row;
5658
import org.apache.cassandra.db.rows.Rows;
@@ -658,10 +660,26 @@ public UnfilteredRowIterator queryMemtableAndDisk(ColumnFamilyStore cfs, ReadExe
658660
assert executionController != null && executionController.validForReadOn(cfs);
659661
Tracing.trace("Executing single-partition query on {}", cfs.name);
660662

661-
return queryMemtableAndDiskInternal(cfs, executionController);
663+
Tracing.trace("Acquiring sstable references");
664+
ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey()));
665+
return queryMemtableAndDiskInternal(cfs, view, null, executionController);
666+
}
667+
668+
public UnfilteredRowIterator queryMemtableAndDisk(ColumnFamilyStore cfs,
669+
ColumnFamilyStore.ViewFragment view,
670+
Function<CellSourceIdentifier, Transformation<BaseRowIterator<?>>> rowTransformer,
671+
ReadExecutionController executionController)
672+
{
673+
assert executionController != null && executionController.validForReadOn(cfs);
674+
Tracing.trace("Executing single-partition query on {}", cfs.name);
675+
676+
return queryMemtableAndDiskInternal(cfs, view, rowTransformer, executionController);
662677
}
663678

664-
private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs, ReadExecutionController controller)
679+
private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs,
680+
ColumnFamilyStore.ViewFragment view,
681+
Function<CellSourceIdentifier, Transformation<BaseRowIterator<?>>> rowTransformer,
682+
ReadExecutionController controller)
665683
{
666684
/*
667685
* We have 2 main strategies:
@@ -685,11 +703,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
685703
&& !queriesMulticellType()
686704
&& !controller.isTrackingRepairedStatus())
687705
{
688-
return queryMemtableAndSSTablesInTimestampOrder(cfs, (ClusteringIndexNamesFilter)clusteringIndexFilter(), controller);
706+
return queryMemtableAndSSTablesInTimestampOrder(cfs, view, rowTransformer, (ClusteringIndexNamesFilter)clusteringIndexFilter(), controller);
689707
}
690708

691-
Tracing.trace("Acquiring sstable references");
692-
ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey()));
693709
view.sstables.sort(SSTableReader.maxTimestampDescending);
694710
ClusteringIndexFilter filter = clusteringIndexFilter();
695711
long minTimestamp = Long.MAX_VALUE;
@@ -708,6 +724,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
708724
if (memtable.getMinTimestamp() != Memtable.NO_MIN_TIMESTAMP)
709725
minTimestamp = Math.min(minTimestamp, memtable.getMinTimestamp());
710726

727+
if (rowTransformer != null)
728+
iter = Transformation.apply(iter, rowTransformer.apply(memtable));
729+
711730
// Memtable data is always considered unrepaired
712731
controller.updateMinOldestUnrepairedTombstone(memtable.getMinLocalDeletionTime());
713732
inputCollector.addMemtableIterator(RTBoundValidator.validate(iter, RTBoundValidator.Stage.MEMTABLE, false));
@@ -767,6 +786,9 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
767786
UnfilteredRowIterator iter = intersects ? makeRowIteratorWithLowerBound(cfs, sstable, metricsCollector)
768787
: makeRowIteratorWithSkippedNonStaticContent(cfs, sstable, metricsCollector);
769788

789+
if (rowTransformer != null)
790+
iter = Transformation.apply(iter, rowTransformer.apply(sstable.getId()));
791+
770792
inputCollector.addSSTableIterator(sstable, iter);
771793
mostRecentPartitionTombstone = Math.max(mostRecentPartitionTombstone,
772794
iter.partitionLevelDeletion().markedForDeleteAt());
@@ -789,6 +811,10 @@ private UnfilteredRowIterator queryMemtableAndDiskInternal(ColumnFamilyStore cfs
789811
{
790812
if (!sstable.isRepaired())
791813
controller.updateMinOldestUnrepairedTombstone(sstable.getMinLocalDeletionTime());
814+
815+
if (rowTransformer != null)
816+
iter = Transformation.apply(iter, rowTransformer.apply(sstable.getId()));
817+
792818
inputCollector.addSSTableIterator(sstable, iter);
793819
includedDueToTombstones++;
794820
mostRecentPartitionTombstone = Math.max(mostRecentPartitionTombstone,
@@ -922,11 +948,8 @@ private boolean queriesMulticellType()
922948
* no collection or counters are included).
923949
* This method assumes the filter is a {@code ClusteringIndexNamesFilter}.
924950
*/
925-
private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFamilyStore cfs, ClusteringIndexNamesFilter filter, ReadExecutionController controller)
951+
private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFamilyStore cfs, ColumnFamilyStore.ViewFragment view, Function<CellSourceIdentifier, Transformation<BaseRowIterator<?>>> rowTransformer, ClusteringIndexNamesFilter filter, ReadExecutionController controller)
926952
{
927-
Tracing.trace("Acquiring sstable references");
928-
ColumnFamilyStore.ViewFragment view = cfs.select(View.select(SSTableSet.LIVE, partitionKey()));
929-
930953
ImmutableBTreePartition result = null;
931954
SSTableReadMetricsCollector metricsCollector = new SSTableReadMetricsCollector();
932955

@@ -938,7 +961,9 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam
938961
if (iter == null)
939962
continue;
940963

941-
result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.MEMTABLE, false),
964+
UnfilteredRowIterator wrapped = rowTransformer != null ? Transformation.apply(iter, rowTransformer.apply(memtable))
965+
: iter;
966+
result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.MEMTABLE, false),
942967
result,
943968
filter,
944969
false,
@@ -993,7 +1018,10 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam
9931018
}
9941019
else
9951020
{
996-
result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.SSTABLE, false),
1021+
UnfilteredRowIterator wrapped = rowTransformer != null ? Transformation.apply(iter, rowTransformer.apply(sstable.getId()))
1022+
: iter;
1023+
1024+
result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.SSTABLE, false),
9971025
result,
9981026
filter,
9991027
sstable.isRepaired(),
@@ -1008,8 +1036,9 @@ private UnfilteredRowIterator queryMemtableAndSSTablesInTimestampOrder(ColumnFam
10081036
{
10091037
if (iter.isEmpty())
10101038
continue;
1011-
1012-
result = add(RTBoundValidator.validate(iter, RTBoundValidator.Stage.SSTABLE, false),
1039+
UnfilteredRowIterator wrapped = rowTransformer != null ? Transformation.apply(iter, rowTransformer.apply(sstable.getId()))
1040+
: iter;
1041+
result = add(RTBoundValidator.validate(wrapped, RTBoundValidator.Stage.SSTABLE, false),
10131042
result,
10141043
filter,
10151044
sstable.isRepaired(),

src/java/org/apache/cassandra/db/lifecycle/Tracker.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ public Memtable switchMemtable(boolean truncating, Memtable newMemtable)
399399
if (truncating)
400400
notifyRenewed(newMemtable);
401401
else
402-
notifySwitched(result.left.getCurrentMemtable());
402+
notifySwitched(result.left.getCurrentMemtable(), result.right.getCurrentMemtable());
403403

404404
return result.left.getCurrentMemtable();
405405
}
@@ -554,9 +554,9 @@ public void notifyRenewed(Memtable renewed)
554554
notify(new MemtableRenewedNotification(renewed));
555555
}
556556

557-
public void notifySwitched(Memtable previous)
557+
public void notifySwitched(Memtable previous, Memtable next)
558558
{
559-
notify(new MemtableSwitchedNotification(previous));
559+
notify(new MemtableSwitchedNotification(previous, next));
560560
}
561561

562562
public void notifyDiscarded(Memtable discarded)

src/java/org/apache/cassandra/db/memtable/Memtable.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import javax.annotation.concurrent.NotThreadSafe;
2424

25+
import org.apache.cassandra.db.CellSourceIdentifier;
2526
import org.apache.cassandra.db.ColumnFamilyStore;
2627
import org.apache.cassandra.db.PartitionPosition;
2728
import org.apache.cassandra.db.RegularAndStaticColumns;
@@ -54,7 +55,7 @@
5455
*
5556
* See Memtable_API.md for details on implementing and using alternative memtable implementations.
5657
*/
57-
public interface Memtable extends Comparable<Memtable>, UnfilteredSource
58+
public interface Memtable extends Comparable<Memtable>, UnfilteredSource, CellSourceIdentifier
5859
{
5960
public static final long NO_MIN_TIMESTAMP = -1;
6061

src/java/org/apache/cassandra/index/sai/QueryContext.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ public class QueryContext
7070
* */
7171
public boolean hasUnrepairedMatches = false;
7272

73-
private VectorQueryContext vectorContext;
74-
7573
public QueryContext(ReadCommand readCommand, long executionQuotaMs)
7674
{
7775
this.readCommand = readCommand;
@@ -93,10 +91,8 @@ public void checkpoint()
9391
}
9492
}
9593

96-
public VectorQueryContext vectorContext()
94+
public int limit()
9795
{
98-
if (vectorContext == null)
99-
vectorContext = new VectorQueryContext(readCommand);
100-
return vectorContext;
96+
return readCommand.limits().count();
10197
}
10298
}

src/java/org/apache/cassandra/index/sai/StorageAttachedIndexGroup.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.apache.cassandra.notifications.INotificationConsumer;
6262
import org.apache.cassandra.notifications.MemtableDiscardedNotification;
6363
import org.apache.cassandra.notifications.MemtableRenewedNotification;
64+
import org.apache.cassandra.notifications.MemtableSwitchedNotification;
6465
import org.apache.cassandra.notifications.SSTableAddedNotification;
6566
import org.apache.cassandra.notifications.SSTableListChangedNotification;
6667
import org.apache.cassandra.schema.TableMetadata;
@@ -275,6 +276,10 @@ else if (notification instanceof MemtableRenewedNotification)
275276
{
276277
indexes.forEach(index -> index.memtableIndexManager().renewMemtable(((MemtableRenewedNotification) notification).renewed));
277278
}
279+
else if (notification instanceof MemtableSwitchedNotification)
280+
{
281+
indexes.forEach(index -> index.memtableIndexManager().maybeInitializeMemtableIndex(((MemtableSwitchedNotification) notification).next));
282+
}
278283
else if (notification instanceof MemtableDiscardedNotification)
279284
{
280285
indexes.forEach(index -> index.memtableIndexManager().discardMemtable(((MemtableDiscardedNotification) notification).memtable));

0 commit comments

Comments
 (0)