diff --git a/src/Build5Nines.SharpVector/BasicDiskMemoryVectorDatabaseBase.cs b/src/Build5Nines.SharpVector/BasicDiskMemoryVectorDatabaseBase.cs index 79d2d4e..645bc46 100644 --- a/src/Build5Nines.SharpVector/BasicDiskMemoryVectorDatabaseBase.cs +++ b/src/Build5Nines.SharpVector/BasicDiskMemoryVectorDatabaseBase.cs @@ -3,6 +3,7 @@ using Build5Nines.SharpVector.Vocabulary; using Build5Nines.SharpVector.Vectorization; using Build5Nines.SharpVector.VectorCompare; +using Build5Nines.SharpVector.VectorEncoding; using Build5Nines.SharpVector.VectorStore; namespace Build5Nines.SharpVector; @@ -26,4 +27,8 @@ public abstract class BasicDiskMemoryVectorDatabaseBase + /// Create a disk-backed database that compresses vectors with the supplied + /// encoding before storing them. + /// + public BasicDiskVectorDatabase(string rootPath, IVectorEncoding encoding) + : base( + new BasicDiskVectorStore, string, int>( + rootPath, + new BasicDiskVocabularyStore(rootPath) + ), + encoding + ) + { } + [Obsolete("Use DeserializeFromBinaryStreamAsync instead.")] public override async Task DeserializeFromJsonStreamAsync(Stream stream) { diff --git a/src/Build5Nines.SharpVector/BasicMemoryVectorDatabase.cs b/src/Build5Nines.SharpVector/BasicMemoryVectorDatabase.cs index 6f3a33c..f1a7ac0 100644 --- a/src/Build5Nines.SharpVector/BasicMemoryVectorDatabase.cs +++ b/src/Build5Nines.SharpVector/BasicMemoryVectorDatabase.cs @@ -1,7 +1,21 @@ +using Build5Nines.SharpVector.VectorEncoding; + namespace Build5Nines.SharpVector; /// /// A basic implementation of an vector database that uses an in-memory dictionary to store vectors, with integer keys and string metadata values. /// public class BasicMemoryVectorDatabase : MemoryVectorDatabase -{ } \ No newline at end of file +{ + public BasicMemoryVectorDatabase() + : base() + { } + + /// + /// Create a database that compresses vectors with the supplied encoding + /// before storing them. 
+ /// + public BasicMemoryVectorDatabase(IVectorEncoding encoding) + : base(encoding) + { } +} diff --git a/src/Build5Nines.SharpVector/DatabaseInfo.cs b/src/Build5Nines.SharpVector/DatabaseInfo.cs index 9e5c0af..9a09546 100644 --- a/src/Build5Nines.SharpVector/DatabaseInfo.cs +++ b/src/Build5Nines.SharpVector/DatabaseInfo.cs @@ -1,3 +1,5 @@ +using System.Text.Json.Serialization; + namespace Build5Nines.SharpVector; public class DatabaseInfo @@ -22,4 +24,12 @@ public DatabaseInfo(string? schema, string? version, string? classType) public string? Schema { get; set; } public string? Version { get; set; } public string? ClassType { get; set; } -} \ No newline at end of file + + /// + /// The id of the vector encoding used for this database, when not the + /// default raw float32 encoding. Omitted from the JSON when null so + /// raw-encoded databases continue to produce the legacy on-disk shape. + /// + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? VectorEncodingId { get; set; } +} diff --git a/src/Build5Nines.SharpVector/MemoryVectorDatabase.cs b/src/Build5Nines.SharpVector/MemoryVectorDatabase.cs index e08b425..b2ae919 100644 --- a/src/Build5Nines.SharpVector/MemoryVectorDatabase.cs +++ b/src/Build5Nines.SharpVector/MemoryVectorDatabase.cs @@ -3,6 +3,7 @@ using Build5Nines.SharpVector.Preprocessing; using Build5Nines.SharpVector.Vectorization; using Build5Nines.SharpVector.VectorCompare; +using Build5Nines.SharpVector.VectorEncoding; using Build5Nines.SharpVector.VectorStore; namespace Build5Nines.SharpVector; @@ -37,6 +38,19 @@ public MemoryVectorDatabase() ) { } + /// + /// Create an in-memory vector database that compresses vectors with the + /// supplied encoding before storing them. 
+ /// + public MemoryVectorDatabase(IVectorEncoding encoding) + : base( + new MemoryDictionaryVectorStoreWithVocabulary, string, int>( + new DictionaryVocabularyStore() + ), + encoding + ) + { } + [Obsolete("Use DeserializeFromBinaryStreamAsync instead.")] public override async Task DeserializeFromJsonStreamAsync(Stream stream) diff --git a/src/Build5Nines.SharpVector/MemoryVectorDatabaseBase.cs b/src/Build5Nines.SharpVector/MemoryVectorDatabaseBase.cs index 6d407d7..749567a 100644 --- a/src/Build5Nines.SharpVector/MemoryVectorDatabaseBase.cs +++ b/src/Build5Nines.SharpVector/MemoryVectorDatabaseBase.cs @@ -3,6 +3,7 @@ using Build5Nines.SharpVector.Vocabulary; using Build5Nines.SharpVector.Vectorization; using Build5Nines.SharpVector.VectorCompare; +using Build5Nines.SharpVector.VectorEncoding; using Build5Nines.SharpVector.VectorStore; using System.Collections.Concurrent; using System.IO.Compression; @@ -43,6 +44,10 @@ public abstract class MemoryVectorDatabaseBase @@ -63,4 +68,8 @@ public abstract class MemoryVectorDatabaseBase VectorComparison.CosineSimilarity; + /// /// Calculates the cosine similarity between two vectors. /// diff --git a/src/Build5Nines.SharpVector/VectorCompare/EuclideanDistanceVectorComparerAsync.cs b/src/Build5Nines.SharpVector/VectorCompare/EuclideanDistanceVectorComparerAsync.cs index 5c425cf..9981988 100644 --- a/src/Build5Nines.SharpVector/VectorCompare/EuclideanDistanceVectorComparerAsync.cs +++ b/src/Build5Nines.SharpVector/VectorCompare/EuclideanDistanceVectorComparerAsync.cs @@ -4,6 +4,8 @@ namespace Build5Nines.SharpVector.VectorCompare; public class EuclideanDistanceVectorComparer : IVectorComparer { + public VectorComparison MetricKind => VectorComparison.EuclideanDistance; + /// /// Calculates the Euclidean distance between two vectors. 
/// diff --git a/src/Build5Nines.SharpVector/VectorCompare/IVectorComparer.cs b/src/Build5Nines.SharpVector/VectorCompare/IVectorComparer.cs index b6e20b7..d32a223 100644 --- a/src/Build5Nines.SharpVector/VectorCompare/IVectorComparer.cs +++ b/src/Build5Nines.SharpVector/VectorCompare/IVectorComparer.cs @@ -2,6 +2,13 @@ namespace Build5Nines.SharpVector.VectorCompare; public interface IVectorComparer { + /// + /// The kind of metric this comparer implements. The database uses this to + /// dispatch to the matching fast-path in an IVectorEncoding + /// when stored vectors are compressed. + /// + VectorComparison MetricKind { get; } + /// /// Calculates a comparison between two vectors /// diff --git a/src/Build5Nines.SharpVector/VectorComparison.cs b/src/Build5Nines.SharpVector/VectorComparison.cs index 572ae5e..4fe1385 100644 --- a/src/Build5Nines.SharpVector/VectorComparison.cs +++ b/src/Build5Nines.SharpVector/VectorComparison.cs @@ -1,3 +1,15 @@ namespace Build5Nines.SharpVector; -//public record VectorComparison(TId Id, float vectorComparison); +/// +/// The kind of similarity/distance metric an IVectorComparer +/// represents. Used by the encoding subsystem to dispatch to the correct +/// fast-path implementation for an encoded vector. +/// +public enum VectorComparison +{ + /// Cosine similarity: higher is more similar, range [-1, 1]. + CosineSimilarity, + + /// Euclidean distance: lower is more similar, range [0, infinity). 
+ EuclideanDistance +} diff --git a/src/Build5Nines.SharpVector/VectorDatabaseBase.cs b/src/Build5Nines.SharpVector/VectorDatabaseBase.cs index 203cb58..79ade7a 100644 --- a/src/Build5Nines.SharpVector/VectorDatabaseBase.cs +++ b/src/Build5Nines.SharpVector/VectorDatabaseBase.cs @@ -3,6 +3,7 @@ using Build5Nines.SharpVector.Vocabulary; using Build5Nines.SharpVector.Vectorization; using Build5Nines.SharpVector.VectorCompare; +using Build5Nines.SharpVector.VectorEncoding; using Build5Nines.SharpVector.VectorStore; using System.Collections.Concurrent; using System.IO.Compression; @@ -53,13 +54,24 @@ public abstract class VectorDatabaseBase protected TVectorStore VectorStore { get; private set; } + /// + /// The encoding applied to vectors as they are written into the store. + /// On load, this is overridden by the encoding recorded in the database file. + /// + public IVectorEncoding VectorEncoding { get; protected internal set; } + public VectorDatabaseBase(TVectorStore vectorStore) + : this(vectorStore, RawFloat32Encoding.Instance) + { } + + public VectorDatabaseBase(TVectorStore vectorStore, IVectorEncoding encoding) { VectorStore = vectorStore; _idGenerator = new TIdGenerator(); _textPreprocessor = new TTextPreprocessor(); _vectorizer = new TVectorizer(); _vectorComparer = new TVectorComparer(); + VectorEncoding = encoding ?? 
RawFloat32Encoding.Instance; } /// @@ -98,11 +110,12 @@ public IEnumerable GetIds() // Generate the vector asynchronously float[] vector = await _vectorizer.GenerateVectorFromTokensAsync(VectorStore.VocabularyStore, tokens); - + // Generate the ID and store the vector text item asynchronously TId id = _idGenerator.NewId(); - await VectorStore.SetAsync(id, new VectorTextItem(text, metadata, vector)); - + var encoded = VectorEncoding.Encode(vector); + await VectorStore.SetAsync(id, new VectorTextItem(text, metadata, encoded)); + return id; } @@ -157,7 +170,8 @@ public void UpdateText(TId id, TVocabularyKey text) VectorStore.VocabularyStore.Update(tokens); float[] vector = _vectorizer.GenerateVectorFromTokens(VectorStore.VocabularyStore, tokens); var metadata = VectorStore.Get(id).Metadata; - VectorStore.Set(id, new VectorTextItem(text, metadata, vector)); + var encoded = VectorEncoding.Encode(vector); + VectorStore.Set(id, new VectorTextItem(text, metadata, encoded)); } else { @@ -179,9 +193,9 @@ public void UpdateTextMetadata(TId id, TMetadata metadata) { var item = new VectorTextItem( existing.Text, metadata, - existing.Vector + existing.EncodedVector ); - + VectorStore.Set(id, item); } else @@ -203,7 +217,8 @@ public void UpdateTextAndMetadata(TId id, TVocabularyKey text, TMetadata metadat var tokens = _textPreprocessor.TokenizeAndPreprocess(text); VectorStore.VocabularyStore.Update(tokens); float[] vector = _vectorizer.GenerateVectorFromTokens(VectorStore.VocabularyStore, tokens); - VectorStore.Set(id, new VectorTextItem(text, metadata, vector)); + var encoded = VectorEncoding.Encode(vector); + VectorStore.Set(id, new VectorTextItem(text, metadata, encoded)); } else { @@ -314,9 +329,15 @@ public virtual async Task SerializeToBinaryStreamAsync(Stream stream) await Task.WhenAll(taskVectorStore, taskVocabularyStore); + var info = new DatabaseInfo(this.GetType().FullName); + if (VectorEncoding.Id != RawFloat32Encoding.EncodingId) + { + info.VectorEncodingId = 
VectorEncoding.Id; + } + await DatabaseFile.SaveDatabaseToZipArchiveAsync( stream, - new DatabaseInfo(this.GetType().FullName), + info, async (archive) => { var entryVectorStore = archive.CreateEntry(DatabaseFile.vectorStoreFilename); @@ -361,7 +382,7 @@ public virtual async Task DeserializeFromJsonStreamAsync(Stream stream) { public virtual async Task DeserializeFromBinaryStreamAsync(Stream stream) { - await DatabaseFile.LoadDatabaseFromZipArchiveAsync( + var info = await DatabaseFile.LoadDatabaseFromZipArchiveAsync( stream, this.GetType().FullName, async (archive) => @@ -381,6 +402,13 @@ await DatabaseFile.LoadDatabaseFromZipArchiveAsync( } } ); + + // If the file recorded an explicit encoding, switch the database to use + // it for any subsequent writes. Absence implies the legacy raw encoding. + if (!string.IsNullOrEmpty(info.VectorEncodingId)) + { + VectorEncoding = VectorEncodingRegistry.Get(info.VectorEncodingId); + } } @@ -438,12 +466,23 @@ public abstract class VectorDatabaseBase + /// The encoding applied to vectors as they are written into the store. + /// On load, this is overridden by the encoding recorded in the database file. + /// + public IVectorEncoding VectorEncoding { get; protected internal set; } + public VectorDatabaseBase(IEmbeddingsGenerator embeddingsGenerator, TVectorStore vectorStore) + : this(embeddingsGenerator, vectorStore, RawFloat32Encoding.Instance) + { } + + public VectorDatabaseBase(IEmbeddingsGenerator embeddingsGenerator, TVectorStore vectorStore, IVectorEncoding encoding) { EmbeddingsGenerator = embeddingsGenerator; VectorStore = vectorStore; _idGenerator = new TIdGenerator(); _vectorComparer = new TVectorComparer(); + VectorEncoding = encoding ?? RawFloat32Encoding.Instance; } /// @@ -473,14 +512,15 @@ public IEnumerable GetIds() /// /// public async Task AddTextAsync(string text, TMetadata? 
metadata = default(TMetadata)) - { + { // Generate the vector asynchronously var vector = await EmbeddingsGenerator.GenerateEmbeddingsAsync(text); - + // Generate the ID and store the vector text item asynchronously TId id = _idGenerator.NewId(); - await VectorStore.SetAsync(id, new VectorTextItem(text, metadata, vector)); - + var encoded = VectorEncoding.Encode(vector); + await VectorStore.SetAsync(id, new VectorTextItem(text, metadata, encoded)); + return id; } @@ -520,7 +560,8 @@ public async Task> AddTextsAsync(IEnumerable<(string text, TM { TId id = _idGenerator.NewId(); ids.Add(id); - await VectorStore.SetAsync(id, new VectorTextItem(list[i].text, list[i].metadata, vectors[i])); + var encoded = VectorEncoding.Encode(vectors[i]); + await VectorStore.SetAsync(id, new VectorTextItem(list[i].text, list[i].metadata, encoded)); } return ids; @@ -559,8 +600,8 @@ public async Task UpdateTextAsync(TId id, string text) { var existing = VectorStore.Get(id); var vector = await EmbeddingsGenerator.GenerateEmbeddingsAsync(text); - var metadata = existing.Metadata; - VectorStore.Set(id, new VectorTextItem(text, existing.Metadata, vector)); + var encoded = VectorEncoding.Encode(vector); + VectorStore.Set(id, new VectorTextItem(text, existing.Metadata, encoded)); } else { @@ -593,9 +634,9 @@ public void UpdateTextMetadata(TId id, TMetadata metadata) { var item = new VectorTextItem( existing.Text, metadata, - existing.Vector + existing.EncodedVector ); - + VectorStore.Set(id, item); } else @@ -615,7 +656,8 @@ public async Task UpdateTextAndMetadataAsync(TId id, string text, TMetadata meta if (VectorStore.ContainsKey(id)) { var vector = await EmbeddingsGenerator.GenerateEmbeddingsAsync(text); - VectorStore.Set(id, new VectorTextItem(text, metadata, vector)); + var encoded = VectorEncoding.Encode(vector); + VectorStore.Set(id, new VectorTextItem(text, metadata, encoded)); } else { @@ -692,14 +734,20 @@ private async Task>> C throw new InvalidOperationException("The database is 
empty."); } + var metricKind = _vectorComparer.MetricKind; var results = new ConcurrentBag>(); await foreach (var kvp in VectorStore) { if (filter == null || await filter(kvp.Value.Metadata)) { var item = kvp.Value; - - float vectorComparisonValue = await _vectorComparer.CalculateAsync(queryVector, item.Vector); + var stored = item.EncodedVector; + // Dispatch to the encoding's fast asymmetric distance path. For + // raw encoding this is equivalent to the comparer's float/float + // calculation; for compressed encodings it avoids decoding back + // to float[] at search time. + var encoding = VectorEncodingRegistry.TryGet(stored.EncodingId, out var e) ? e : VectorEncoding; + float vectorComparisonValue = encoding.Compare(metricKind, queryVector, stored); if (_vectorComparer.IsWithinThreshold(threshold, vectorComparisonValue)) { @@ -730,9 +778,15 @@ public virtual async Task SerializeToBinaryStreamAsync(Stream stream) var streamVectorStore = new MemoryStream(); await VectorStore.SerializeToJsonStreamAsync(streamVectorStore); + var info = new DatabaseInfo(this.GetType().FullName); + if (VectorEncoding.Id != RawFloat32Encoding.EncodingId) + { + info.VectorEncodingId = VectorEncoding.Id; + } + await DatabaseFile.SaveDatabaseToZipArchiveAsync( stream, - new DatabaseInfo(this.GetType().FullName), + info, async (archive) => { var entryVectorStore = archive.CreateEntry(DatabaseFile.vectorStoreFilename); @@ -770,7 +824,7 @@ public virtual async Task DeserializeFromJsonStreamAsync(Stream stream) public virtual async Task DeserializeFromBinaryStreamAsync(Stream stream) { - await DatabaseFile.LoadDatabaseFromZipArchiveAsync( + var info = await DatabaseFile.LoadDatabaseFromZipArchiveAsync( stream, this.GetType().FullName, async (archive) => @@ -789,6 +843,11 @@ await DatabaseFile.LoadDatabaseFromZipArchiveAsync( } } ); + + if (!string.IsNullOrEmpty(info.VectorEncodingId)) + { + VectorEncoding = VectorEncodingRegistry.Get(info.VectorEncodingId); + } } [Obsolete("Use 
DeserializeFromBinaryStream Instead")] diff --git a/src/Build5Nines.SharpVector/VectorEncoding/IEncodedVector.cs b/src/Build5Nines.SharpVector/VectorEncoding/IEncodedVector.cs new file mode 100644 index 0000000..4932021 --- /dev/null +++ b/src/Build5Nines.SharpVector/VectorEncoding/IEncodedVector.cs @@ -0,0 +1,32 @@ +namespace Build5Nines.SharpVector.VectorEncoding; + +/// +/// A vector that has been encoded by an IVectorEncoding. +/// The encoded form is what the database actually stores; the original +/// floating-point values are recoverable via Decode(). +/// +public interface IEncodedVector +{ + /// + /// Identifier of the encoding that produced this vector. Used to look up + /// the matching IVectorEncoding when deserializing. + /// + string EncodingId { get; } + + /// + /// Logical dimensionality of the original float vector. + /// + int Dimensions { get; } + + /// + /// Raw bytes of the encoded payload, suitable for persistence. + /// + byte[] GetBytes(); + + /// + /// Reconstructs an approximation of the original float vector. + /// For lossless encodings this is exact; for compressed encodings it + /// is a lossy approximation. + /// + float[] Decode(); +} diff --git a/src/Build5Nines.SharpVector/VectorEncoding/IVectorEncoding.cs b/src/Build5Nines.SharpVector/VectorEncoding/IVectorEncoding.cs new file mode 100644 index 0000000..342ee25 --- /dev/null +++ b/src/Build5Nines.SharpVector/VectorEncoding/IVectorEncoding.cs @@ -0,0 +1,41 @@ +namespace Build5Nines.SharpVector.VectorEncoding; + +/// +/// A strategy for compressing/encoding float vectors and computing +/// similarity directly against the encoded form. +/// +/// +/// Implementations are responsible for the full round-trip of one encoded +/// vector type: encode a float[] into bytes, restore an encoded vector from +/// bytes, decode back to float[] when needed, and compute similarity between +/// a float query vector and an encoded stored vector for each supported +/// metric. 
+/// +public interface IVectorEncoding +{ + /// + /// Stable identifier used to tag persisted vectors and look the encoding + /// up via VectorEncodingRegistry. + /// + string Id { get; } + + /// + /// Encode a float vector into the encoding's compressed form. + /// + IEncodedVector Encode(float[] vector); + + /// + /// Reconstruct an encoded vector from previously persisted bytes. + /// + /// The raw payload produced by IEncodedVector.GetBytes(). + /// The original float vector dimensionality. + IEncodedVector LoadFromBytes(byte[] bytes, int dimensions); + + /// + /// Compute similarity between a float query vector and a stored encoded + /// vector using the requested metric. Implementations should use the fast + /// asymmetric path (query stays float, stored stays encoded) whenever + /// possible; otherwise fall back to IEncodedVector.Decode(). + /// + float Compare(VectorComparison metric, float[] query, IEncodedVector encoded); +} diff --git a/src/Build5Nines.SharpVector/VectorEncoding/Int8ScalarQuantizationEncoding.cs b/src/Build5Nines.SharpVector/VectorEncoding/Int8ScalarQuantizationEncoding.cs new file mode 100644 index 0000000..38a7d5f --- /dev/null +++ b/src/Build5Nines.SharpVector/VectorEncoding/Int8ScalarQuantizationEncoding.cs @@ -0,0 +1,161 @@ +using System; +using System.Buffers.Binary; + +namespace Build5Nines.SharpVector.VectorEncoding; + +/// +/// Symmetric per-vector int8 scalar quantization. Each vector chooses its own +/// scale equal to its absolute-max value divided by 127, then quantizes every +/// component into a signed byte. Storage shrinks from 4 bytes/dim to roughly +/// 1 byte/dim (plus a single 4-byte scale per vector). +/// +/// +/// For cosine similarity the per-vector scale cancels out, so quality loss is +/// limited to the rounding error of the int8 codes. Euclidean distance +/// reconstructs values as code * scale before differencing. 
+/// +public sealed class Int8ScalarQuantizationEncoding : IVectorEncoding +{ + public const string EncodingId = "int8-sq"; + + public static readonly Int8ScalarQuantizationEncoding Instance = new(); + + public string Id => EncodingId; + + public IEncodedVector Encode(float[] vector) + { + if (vector is null) throw new ArgumentNullException(nameof(vector)); + + float absMax = 0f; + for (int i = 0; i < vector.Length; i++) + { + float a = Math.Abs(vector[i]); + if (a > absMax) absMax = a; + } + + // A zero vector encodes as all-zero codes with scale 0; decode round-trips to zero. + float scale = absMax / 127f; + var codes = new sbyte[vector.Length]; + if (scale > 0f) + { + float inv = 1f / scale; + for (int i = 0; i < vector.Length; i++) + { + int q = (int)MathF.Round(vector[i] * inv); + if (q > 127) q = 127; + else if (q < -127) q = -127; + codes[i] = (sbyte)q; + } + } + + return new Int8EncodedVector(scale, codes); + } + + public IEncodedVector LoadFromBytes(byte[] bytes, int dimensions) + { + if (bytes is null) throw new ArgumentNullException(nameof(bytes)); + if (bytes.Length != sizeof(float) + dimensions) + throw new ArgumentException( + $"Expected {sizeof(float) + dimensions} bytes for int8-sq of {dimensions} dims, got {bytes.Length}.", + nameof(bytes)); + + float scale = BinaryPrimitives.ReadSingleLittleEndian(bytes.AsSpan(0, sizeof(float))); + var codes = new sbyte[dimensions]; + for (int i = 0; i < dimensions; i++) + { + codes[i] = (sbyte)bytes[sizeof(float) + i]; + } + return new Int8EncodedVector(scale, codes); + } + + public float Compare(VectorComparison metric, float[] query, IEncodedVector encoded) + { + if (encoded is not Int8EncodedVector q) + throw new ArgumentException( + $"Int8ScalarQuantizationEncoding cannot compare against encoding '{encoded.EncodingId}'.", + nameof(encoded)); + + if (query.Length != q.Codes.Length) + throw new ArgumentException("Vectors must be of the same length."); + + return metric switch + { + 
VectorComparison.CosineSimilarity => CosineSimilarity(query, q), + VectorComparison.EuclideanDistance => EuclideanDistance(query, q), + _ => throw new ArgumentOutOfRangeException(nameof(metric), metric, null) + }; + } + + private static float CosineSimilarity(float[] query, Int8EncodedVector stored) + { + // cos = sum(q_i * c_i) / (|query| * sqrt(sum(c_i^2))) + // The scale factor cancels because it appears identically in numerator and + // in the magnitude of the decoded stored vector. + float dot = 0f; + long codeSqSum = 0; + float qSqSum = 0f; + var codes = stored.Codes; + for (int i = 0; i < codes.Length; i++) + { + float qi = query[i]; + int ci = codes[i]; + dot += qi * ci; + codeSqSum += ci * ci; + qSqSum += qi * qi; + } + float magQuery = (float)Math.Sqrt(qSqSum); + float magCodes = (float)Math.Sqrt(codeSqSum); + if (magQuery == 0f || magCodes == 0f) return 0f; + return dot / (magQuery * magCodes); + } + + private static float EuclideanDistance(float[] query, Int8EncodedVector stored) + { + float scale = stored.Scale; + var codes = stored.Codes; + float sum = 0f; + for (int i = 0; i < codes.Length; i++) + { + float d = query[i] - codes[i] * scale; + sum += d * d; + } + return (float)Math.Sqrt(sum); + } + + private sealed class Int8EncodedVector : IEncodedVector + { + internal readonly float Scale; + internal readonly sbyte[] Codes; + + public Int8EncodedVector(float scale, sbyte[] codes) + { + Scale = scale; + Codes = codes; + } + + public string EncodingId => Int8ScalarQuantizationEncoding.EncodingId; + + public int Dimensions => Codes.Length; + + public byte[] GetBytes() + { + var bytes = new byte[sizeof(float) + Codes.Length]; + BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(0, sizeof(float)), Scale); + for (int i = 0; i < Codes.Length; i++) + { + bytes[sizeof(float) + i] = (byte)Codes[i]; + } + return bytes; + } + + public float[] Decode() + { + var values = new float[Codes.Length]; + for (int i = 0; i < Codes.Length; i++) + { + values[i] = 
Codes[i] * Scale; + } + return values; + } + } +} diff --git a/src/Build5Nines.SharpVector/VectorEncoding/RaBitQEncoding.cs b/src/Build5Nines.SharpVector/VectorEncoding/RaBitQEncoding.cs new file mode 100644 index 0000000..15dc856 --- /dev/null +++ b/src/Build5Nines.SharpVector/VectorEncoding/RaBitQEncoding.cs @@ -0,0 +1,188 @@ +using System; +using System.Buffers.Binary; + +namespace Build5Nines.SharpVector.VectorEncoding; + +/// +/// Rotation-free RaBitQ-style 1-bit binary quantization. +/// Each vector is stored as a sign-bit code plus two scalar correction terms +/// (its L2 norm and a per-vector reconstruction factor) so that asymmetric +/// inner-product / cosine similarity against a float query vector can be +/// recovered with reasonable accuracy. +/// +/// +/// The published RaBitQ algorithm pre-rotates database and query vectors with +/// a shared random orthonormal matrix before quantizing, which gives the +/// unbiased estimator strong concentration bounds for arbitrary input +/// distributions. This implementation omits the rotation because the +/// IVectorEncoding abstraction is a registry-managed singleton +/// with no per-database state; for already-isotropic embedding outputs (the +/// typical input for this library) the rotation-free estimator is still close +/// to the rotated variant in practice. +/// +/// Storage per vector: 8 bytes of scalar correction + ceil(D / 8) bytes of +/// packed sign bits — roughly 1 bit per dimension, a ~32x reduction over +/// raw float32 for high-dimensional embeddings. 
+/// +public sealed class RaBitQEncoding : IVectorEncoding +{ + public const string EncodingId = "rabitq"; + + public static readonly RaBitQEncoding Instance = new(); + + public string Id => EncodingId; + + public IEncodedVector Encode(float[] vector) + { + if (vector is null) throw new ArgumentNullException(nameof(vector)); + int d = vector.Length; + + float normSq = 0f; + float sumAbs = 0f; + for (int i = 0; i < d; i++) + { + normSq += vector[i] * vector[i]; + sumAbs += MathF.Abs(vector[i]); + } + float norm = MathF.Sqrt(normSq); + + // correction = dot(sign(v) / sqrt(D), v / ||v||) + // = (sum(|v_i|) / ||v||) / sqrt(D) + // For a zero vector both norm and correction are zero; the decoder + // treats this as a zero estimate. + float correction = (norm > 0f) + ? (sumAbs / norm) / MathF.Sqrt(d) + : 0f; + + var bits = new byte[(d + 7) / 8]; + for (int i = 0; i < d; i++) + { + if (vector[i] >= 0f) + { + bits[i >> 3] |= (byte)(1 << (i & 7)); + } + } + + return new RaBitQVector(norm, correction, bits, d); + } + + public IEncodedVector LoadFromBytes(byte[] bytes, int dimensions) + { + if (bytes is null) throw new ArgumentNullException(nameof(bytes)); + int expected = 2 * sizeof(float) + (dimensions + 7) / 8; + if (bytes.Length != expected) + throw new ArgumentException( + $"Expected {expected} bytes for rabitq of {dimensions} dims, got {bytes.Length}.", + nameof(bytes)); + + float norm = BinaryPrimitives.ReadSingleLittleEndian(bytes.AsSpan(0, sizeof(float))); + float correction = BinaryPrimitives.ReadSingleLittleEndian(bytes.AsSpan(sizeof(float), sizeof(float))); + var bits = new byte[(dimensions + 7) / 8]; + Buffer.BlockCopy(bytes, 2 * sizeof(float), bits, 0, bits.Length); + return new RaBitQVector(norm, correction, bits, dimensions); + } + + public float Compare(VectorComparison metric, float[] query, IEncodedVector encoded) + { + if (encoded is not RaBitQVector rab) + throw new ArgumentException( + $"RaBitQEncoding cannot compare against encoding '{encoded.EncodingId}'.", + nameof(encoded)); + if 
(query.Length != rab.Dimensions) + throw new ArgumentException("Vectors must be of the same length."); + + // Compute dot(q, c) where c_i ∈ {+1, -1} from the packed sign bits. + float qDotC = 0f; + float qNormSq = 0f; + var bits = rab.Bits; + int d = rab.Dimensions; + for (int i = 0; i < d; i++) + { + float sign = ((bits[i >> 3] >> (i & 7)) & 1) == 1 ? 1f : -1f; + qDotC += query[i] * sign; + qNormSq += query[i] * query[i]; + } + + // Estimate dot(q, d_hat), the raw query projected onto the unit-normalized + // stored vector d_hat = d / ||d||: dot(q, d_hat) ≈ (qDotC / sqrt(D)) / correction. + // The query is not normalized here (so estCosineUnit is a cosine only when + // ||q|| = 1); scaling by ||d|| alone gives the estimated inner product dot(q, d). + float estDot; + if (rab.Correction > 0f && rab.Norm > 0f) + { + float invSqrtD = 1f / MathF.Sqrt(d); + float estCosineUnit = (qDotC * invSqrtD) / rab.Correction; + estDot = rab.Norm * estCosineUnit; + } + else + { + estDot = 0f; + } + + return metric switch + { + VectorComparison.CosineSimilarity => CosineFromEstDot(estDot, qNormSq, rab.Norm), + VectorComparison.EuclideanDistance => EuclideanFromEstDot(estDot, qNormSq, rab.Norm), + _ => throw new ArgumentOutOfRangeException(nameof(metric), metric, null) + }; + } + + private static float CosineFromEstDot(float estDot, float qNormSq, float dNorm) + { + float qNorm = MathF.Sqrt(qNormSq); + if (qNorm == 0f || dNorm == 0f) return 0f; + return estDot / (qNorm * dNorm); + } + + private static float EuclideanFromEstDot(float estDot, float qNormSq, float dNorm) + { + // ||q - d||² = ||q||² + ||d||² - 2 * dot(q, d) + float sq = qNormSq + dNorm * dNorm - 2f * estDot; + if (sq < 0f) sq = 0f; + return MathF.Sqrt(sq); + } + + private sealed class RaBitQVector : IEncodedVector + { + internal readonly float Norm; + internal readonly float Correction; + internal readonly byte[] Bits; + + public RaBitQVector(float norm, float correction, byte[] bits, int dimensions) + { + Norm = norm; + Correction = correction; + Bits = bits; + Dimensions = dimensions; + } + + public string EncodingId => 
RaBitQEncoding.EncodingId; + + public int Dimensions { get; } + + public byte[] GetBytes() + { + var bytes = new byte[2 * sizeof(float) + Bits.Length]; + BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(0, sizeof(float)), Norm); + BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(sizeof(float), sizeof(float)), Correction); + Buffer.BlockCopy(Bits, 0, bytes, 2 * sizeof(float), Bits.Length); + return bytes; + } + + public float[] Decode() + { + // Reconstruct an approximation: each dim recovers as + // sign_i * (||d|| / sqrt(D)) + // which is the best single-magnitude reconstruction given only the + // sign bit and the L2 norm. + var values = new float[Dimensions]; + float magnitude = (Dimensions > 0) ? Norm / MathF.Sqrt(Dimensions) : 0f; + for (int i = 0; i < Dimensions; i++) + { + float sign = ((Bits[i >> 3] >> (i & 7)) & 1) == 1 ? 1f : -1f; + values[i] = sign * magnitude; + } + return values; + } + } +} diff --git a/src/Build5Nines.SharpVector/VectorEncoding/RawFloat32Encoding.cs b/src/Build5Nines.SharpVector/VectorEncoding/RawFloat32Encoding.cs new file mode 100644 index 0000000..8153b7f --- /dev/null +++ b/src/Build5Nines.SharpVector/VectorEncoding/RawFloat32Encoding.cs @@ -0,0 +1,114 @@ +using System; +using System.Buffers.Binary; + +namespace Build5Nines.SharpVector.VectorEncoding; + +/// +/// Lossless passthrough encoding that stores vectors as their original +/// float32 values. This is the default encoding and produces the same +/// on-disk representation as previous versions of the library when +/// combined with the legacy JSON serializer. +/// +public sealed class RawFloat32Encoding : IVectorEncoding +{ + public const string EncodingId = "raw-f32"; + + public static readonly RawFloat32Encoding Instance = new(); + + public string Id => EncodingId; + + public IEncodedVector Encode(float[] vector) + { + if (vector is null) throw new ArgumentNullException(nameof(vector)); + // Defensive copy so caller mutations don't bleed into the store. 
+ var copy = new float[vector.Length]; + Buffer.BlockCopy(vector, 0, copy, 0, vector.Length * sizeof(float)); + return new RawEncodedVector(copy); + } + + public IEncodedVector LoadFromBytes(byte[] bytes, int dimensions) + { + if (bytes is null) throw new ArgumentNullException(nameof(bytes)); + if (bytes.Length != dimensions * sizeof(float)) + throw new ArgumentException( + $"Expected {dimensions * sizeof(float)} bytes for {dimensions} float32 values, got {bytes.Length}.", + nameof(bytes)); + + var values = new float[dimensions]; + Buffer.BlockCopy(bytes, 0, values, 0, bytes.Length); + return new RawEncodedVector(values); + } + + public float Compare(VectorComparison metric, float[] query, IEncodedVector encoded) + { + if (encoded is not RawEncodedVector raw) + raw = new RawEncodedVector(encoded.Decode()); + + return metric switch + { + VectorComparison.CosineSimilarity => CosineSimilarity(query, raw.Values), + VectorComparison.EuclideanDistance => EuclideanDistance(query, raw.Values), + _ => throw new ArgumentOutOfRangeException(nameof(metric), metric, null) + }; + } + + internal static float CosineSimilarity(float[] a, float[] b) + { + if (a.Length != b.Length) + throw new ArgumentException("Vectors must be of the same length."); + + float dot = 0, magA = 0, magB = 0; + for (int i = 0; i < a.Length; i++) + { + dot += a[i] * b[i]; + magA += a[i] * a[i]; + magB += b[i] * b[i]; + } + magA = (float)Math.Sqrt(magA); + magB = (float)Math.Sqrt(magB); + if (magA == 0f || magB == 0f) return 0f; + return dot / (magA * magB); + } + + internal static float EuclideanDistance(float[] a, float[] b) + { + if (a.Length != b.Length) + throw new ArgumentException("Vectors must be of the same length."); + + float sum = 0; + for (int i = 0; i < a.Length; i++) + { + float d = a[i] - b[i]; + sum += d * d; + } + return (float)Math.Sqrt(sum); + } + + private sealed class RawEncodedVector : IEncodedVector + { + internal readonly float[] Values; + + public RawEncodedVector(float[] values) 
+        {
+            Values = values;
+        }
+
+        public string EncodingId => RawFloat32Encoding.EncodingId;
+
+        public int Dimensions => Values.Length;
+
+        public byte[] GetBytes()
+        {
+            var bytes = new byte[Values.Length * sizeof(float)];
+            Buffer.BlockCopy(Values, 0, bytes, 0, bytes.Length);
+            return bytes;
+        }
+
+        public float[] Decode()
+        {
+            var copy = new float[Values.Length];
+            Buffer.BlockCopy(Values, 0, copy, 0, Values.Length * sizeof(float));
+            return copy;
+        }
+    }
+}
diff --git a/src/Build5Nines.SharpVector/VectorEncoding/TurboQuantEncoding.cs b/src/Build5Nines.SharpVector/VectorEncoding/TurboQuantEncoding.cs
new file mode 100644
index 0000000..4288c22
--- /dev/null
+++ b/src/Build5Nines.SharpVector/VectorEncoding/TurboQuantEncoding.cs
@@ -0,0 +1,186 @@
+using System;
+using System.Buffers.Binary;
+
+namespace Build5Nines.SharpVector.VectorEncoding;
+
+/// <summary>
+/// 4-bit symmetric scalar quantization with two codes packed per byte.
+/// Each vector stores a single per-vector scale plus a packed nibble stream.
+/// Storage shrinks from 4 bytes/dim to ~0.5 bytes/dim (plus a 4-byte scale),
+/// roughly an 8x reduction with quality between RaBitQ and int8-sq.
+/// </summary>
+/// <remarks>
+/// The "TurboQUANT" name in the literature refers to several distinct schemes,
+/// most including a rotation transform and SIMD-tuned distance kernels. This
+/// implementation is the simpler 4-bit symmetric scalar-quantization core that
+/// can be swapped for a more elaborate variant later without changing the
+/// public surface. Cosine similarity benefits from the scale cancelling in
+/// numerator and denominator; Euclidean distance reconstructs values as
+/// <c>code * scale</c> before differencing.
+/// </remarks>
+public sealed class TurboQuantEncoding : IVectorEncoding
+{
+    public const string EncodingId = "turboquant";
+
+    // Symmetric 4-bit range uses codes -7..+7. The -8 code is not produced by
+    // the encoder so the rounding behavior stays symmetric; sign-extending on
+    // read still treats 0x8 (= -8) correctly if encountered.
+    private const int MaxCode = 7;
+
+    public static readonly TurboQuantEncoding Instance = new();
+
+    public string Id => EncodingId;
+
+    /// <summary>
+    /// Quantizes <paramref name="vector"/> to signed 4-bit codes sharing one scale
+    /// equal to the largest absolute component divided by 7.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
+    public IEncodedVector Encode(float[] vector)
+    {
+        if (vector is null) throw new ArgumentNullException(nameof(vector));
+        int d = vector.Length;
+
+        float absMax = 0f;
+        for (int i = 0; i < d; i++)
+        {
+            float a = MathF.Abs(vector[i]);
+            if (a > absMax) absMax = a;
+        }
+
+        // An all-zero or empty vector leaves scale at 0; the codes then keep
+        // their default value of 0.
+        float scale = absMax / MaxCode;
+        var codes = new sbyte[d];
+        if (scale > 0f)
+        {
+            float inv = 1f / scale;
+            for (int i = 0; i < d; i++)
+            {
+                int q = (int)MathF.Round(vector[i] * inv);
+                if (q > MaxCode) q = MaxCode;
+                else if (q < -MaxCode) q = -MaxCode;
+                codes[i] = (sbyte)q;
+            }
+        }
+
+        return new TurboQuantVector(scale, codes);
+    }
+
+    /// <summary>
+    /// Rebuilds an encoded vector from a little-endian float32 scale followed by
+    /// ceil(dimensions / 2) bytes of packed nibbles (even index in the low nibble).
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="bytes"/> is null.</exception>
+    /// <exception cref="ArgumentOutOfRangeException"><paramref name="dimensions"/> is negative.</exception>
+    /// <exception cref="ArgumentException">The buffer length does not match <paramref name="dimensions"/>.</exception>
+    public IEncodedVector LoadFromBytes(byte[] bytes, int dimensions)
+    {
+        if (bytes is null) throw new ArgumentNullException(nameof(bytes));
+        // Without this guard dimensions = -1 yields an expected length of 4,
+        // passes the length check and fails later inside the array allocation.
+        if (dimensions < 0)
+            throw new ArgumentOutOfRangeException(nameof(dimensions), dimensions, "Dimensions must be non-negative.");
+
+        // Long arithmetic keeps (dimensions + 1) from wrapping at int.MaxValue.
+        long expected = sizeof(float) + ((long)dimensions + 1) / 2;
+        if (bytes.Length != expected)
+            throw new ArgumentException(
+                $"Expected {expected} bytes for turboquant of {dimensions} dims, got {bytes.Length}.",
+                nameof(bytes));
+
+        float scale = BinaryPrimitives.ReadSingleLittleEndian(bytes.AsSpan(0, sizeof(float)));
+        var codes = new sbyte[dimensions];
+        int payloadStart = sizeof(float);
+        for (int i = 0; i < dimensions; i++)
+        {
+            byte twoCodes = bytes[payloadStart + (i >> 1)];
+            int nibble = ((i & 1) == 0) ? (twoCodes & 0x0F) : ((twoCodes >> 4) & 0x0F);
+            // Sign-extend 4-bit value to 8-bit signed.
+            if (nibble >= 8) nibble -= 16;
+            codes[i] = (sbyte)nibble;
+        }
+        return new TurboQuantVector(scale, codes);
+    }
+
+    /// <summary>
+    /// Compares a float query with a vector produced by this encoding.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="encoded"/> is null.</exception>
+    /// <exception cref="ArgumentException">
+    /// <paramref name="encoded"/> came from another encoding, or the lengths differ.
+    /// </exception>
+    /// <exception cref="ArgumentOutOfRangeException"><paramref name="metric"/> is not a supported comparison.</exception>
+    public float Compare(VectorComparison metric, float[] query, IEncodedVector encoded)
+    {
+        if (query is null) throw new ArgumentNullException(nameof(query));
+        if (encoded is null) throw new ArgumentNullException(nameof(encoded));
+
+        if (encoded is not TurboQuantVector tq)
+            throw new ArgumentException(
+                $"TurboQuantEncoding cannot compare against encoding '{encoded.EncodingId}'.",
+                nameof(encoded));
+        if (query.Length != tq.Codes.Length)
+            throw new ArgumentException("Vectors must be of the same length.");
+
+        return metric switch
+        {
+            VectorComparison.CosineSimilarity => CosineSimilarity(query, tq),
+            VectorComparison.EuclideanDistance => EuclideanDistance(query, tq),
+            _ => throw new ArgumentOutOfRangeException(nameof(metric), metric, null)
+        };
+    }
+
+    private static float CosineSimilarity(float[] query, TurboQuantVector stored)
+    {
+        // Identical reasoning to Int8 SQ: the per-vector scale appears in both
+        // numerator and the magnitude of the decoded vector, so it cancels.
+        float dot = 0f;
+        long codeSqSum = 0;
+        float qSqSum = 0f;
+        var codes = stored.Codes;
+        for (int i = 0; i < codes.Length; i++)
+        {
+            float qi = query[i];
+            int ci = codes[i];
+            dot += qi * ci;
+            codeSqSum += ci * ci;
+            qSqSum += qi * qi;
+        }
+        float magQuery = MathF.Sqrt(qSqSum);
+        float magCodes = MathF.Sqrt(codeSqSum);
+        if (magQuery == 0f || magCodes == 0f) return 0f;
+        return dot / (magQuery * magCodes);
+    }
+
+    private static float EuclideanDistance(float[] query, TurboQuantVector stored)
+    {
+        float scale = stored.Scale;
+        var codes = stored.Codes;
+        float sum = 0f;
+        for (int i = 0; i < codes.Length; i++)
+        {
+            float d = query[i] - codes[i] * scale;
+            sum += d * d;
+        }
+        return MathF.Sqrt(sum);
+    }
+
+    private sealed class TurboQuantVector : IEncodedVector
+    {
+        internal readonly float Scale;
+        internal readonly sbyte[] Codes;
+
+        public TurboQuantVector(float scale, sbyte[] codes)
+        {
+            Scale = scale;
+            Codes = codes;
+        }
+
+        public string EncodingId => TurboQuantEncoding.EncodingId;
+
+        public int Dimensions => Codes.Length;
+
+        public byte[] GetBytes()
+        {
+            int d = Codes.Length;
+            int payloadLen = (d + 1) / 2;
+            var bytes = new byte[sizeof(float) + payloadLen];
+            BinaryPrimitives.WriteSingleLittleEndian(bytes.AsSpan(0, sizeof(float)), Scale);
+
+            int payloadStart = sizeof(float);
+            for (int i = 0; i < d; i++)
+            {
+                int nibble = Codes[i] & 0x0F; // mask to 4 bits (two's complement preserved)
+                if ((i & 1) == 0)
+                {
+                    bytes[payloadStart + (i >> 1)] = (byte)nibble;
+                }
+                else
+                {
+                    bytes[payloadStart + (i >> 1)] |= (byte)(nibble << 4);
+                }
+            }
+            return bytes;
+        }
+
+        public float[] Decode()
+        {
+            var values = new float[Codes.Length];
+            for (int i = 0; i < Codes.Length; i++)
+            {
+                values[i] = Codes[i] * Scale;
+            }
+            return values;
+        }
+    }
+}
diff --git a/src/Build5Nines.SharpVector/VectorEncoding/VectorEncodingRegistry.cs b/src/Build5Nines.SharpVector/VectorEncoding/VectorEncodingRegistry.cs
new file mode 100644
index 0000000..5bb0003
--- /dev/null
+++ b/src/Build5Nines.SharpVector/VectorEncoding/VectorEncodingRegistry.cs
@@ -0,0 +1,52 @@
+using System;
+using System.Collections.Concurrent;
+
+namespace Build5Nines.SharpVector.VectorEncoding;
+
+/// <summary>
+/// Lookup for <see cref="IVectorEncoding"/> instances by id. Used by the
+/// persistence layer to rehydrate encoded vectors when a database is loaded
+/// from disk.
+/// </summary>
+public static class VectorEncodingRegistry
+{
+    private static readonly ConcurrentDictionary<string, IVectorEncoding> _encodings;
+
+    static VectorEncodingRegistry()
+    {
+        _encodings = new ConcurrentDictionary<string, IVectorEncoding>(StringComparer.OrdinalIgnoreCase);
+        Register(RawFloat32Encoding.Instance);
+        Register(Int8ScalarQuantizationEncoding.Instance);
+        Register(RaBitQEncoding.Instance);
+        Register(TurboQuantEncoding.Instance);
+    }
+
+    /// <summary>
+    /// Register a custom encoding. Re-registering an existing id replaces the entry.
+    /// </summary>
+    public static void Register(IVectorEncoding encoding)
+    {
+        if (encoding is null) throw new ArgumentNullException(nameof(encoding));
+        if (string.IsNullOrEmpty(encoding.Id))
+            throw new ArgumentException("Encoding must have a non-empty Id.", nameof(encoding));
+        _encodings[encoding.Id] = encoding;
+    }
+
+    /// <summary>
+    /// Resolve an encoding by id. Throws if the id is unknown.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="encodingId"/> is null.</exception>
+    /// <exception cref="KeyNotFoundException">No encoding is registered under the id.</exception>
+    public static IVectorEncoding Get(string encodingId)
+    {
+        if (encodingId is null) throw new ArgumentNullException(nameof(encodingId));
+        if (_encodings.TryGetValue(encodingId, out var enc)) return enc;
+        throw new KeyNotFoundException($"No vector encoding registered with id '{encodingId}'.");
+    }
+
+    /// <summary>
+    /// Resolve an encoding by id without throwing. Returns false when the id
+    /// is null or not registered.
+    /// </summary>
+    public static bool TryGet(string encodingId, out IVectorEncoding encoding)
+    {
+        // The dictionary lookup throws on a null key, which would break the
+        // non-throwing contract of this method.
+        if (encodingId is null)
+        {
+            encoding = null!;
+            return false;
+        }
+
+        return _encodings.TryGetValue(encodingId, out encoding!);
+    }
+}
diff --git a/src/Build5Nines.SharpVector/VectorTextItem.cs b/src/Build5Nines.SharpVector/VectorTextItem.cs
index f3d03bf..0e2672e 100644
--- a/src/Build5Nines.SharpVector/VectorTextItem.cs
+++ b/src/Build5Nines.SharpVector/VectorTextItem.cs
@@ -1,3 +1,7 @@
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Build5Nines.SharpVector.VectorEncoding;
+
 namespace Build5Nines.SharpVector;
 
 ///
@@ -9,7 +13,26 @@ public interface IVectorTextItem
 {
     TDocument Text { get; set; }
     TMetadata? Metadata { get; set; }
+
+    /// <summary>
+    /// The float vector representation. When the item is backed by a
+    /// compressed encoding this returns a decoded approximation; assigning
+    /// replaces the encoded vector with a fresh raw (lossless) encoding.
+    /// </summary>
     float[] Vector { get; set; }
+
+    /// <summary>
+    /// The encoded form actually stored by the database. This is the
+    /// authoritative representation: <see cref="Vector"/> is derived from it.
+    /// The default implementation adapts to/from <see cref="Vector"/> via
+    /// raw float32 encoding so existing external implementations of this
+    /// interface keep compiling.
+    /// </summary>
+    IEncodedVector EncodedVector
+    {
+        get => RawFloat32Encoding.Instance.Encode(Vector);
+        set => Vector = value.Decode();
+    }
 }
 
 ///
@@ -24,18 +47,33 @@ public interface IVectorTextItem : IVectorTextItem
 ///
 ///
 ///
+[JsonConverter(typeof(VectorTextItemJsonConverterFactory))]
 public class VectorTextItem : IVectorTextItem
 {
     public VectorTextItem(TDocument text, TMetadata? metadata, float[] vector)
     {
         Text = text;
         Metadata = metadata;
-        Vector = vector;
+        // Float input is wrapped in the raw float32 encoding.
+        EncodedVector = RawFloat32Encoding.Instance.Encode(vector);
+    }
+
+    /// <summary>
+    /// Creates an item from an already-encoded vector.
+    /// </summary>
+    /// <exception cref="ArgumentNullException"><paramref name="encodedVector"/> is null.</exception>
+    public VectorTextItem(TDocument text, TMetadata? metadata, IEncodedVector encodedVector)
+    {
+        Text = text;
+        Metadata = metadata;
+        EncodedVector = encodedVector ?? throw new ArgumentNullException(nameof(encodedVector));
     }
-    
+
     public TDocument Text { get; set; }
     public TMetadata? Metadata { get; set; }
-    public float[] Vector { get; set; }
+
+    /// <summary>
+    /// The stored encoded vector; <see cref="Vector"/> reads from and writes to it.
+    /// </summary>
+    public IEncodedVector EncodedVector { get; set; }
+
+    /// <summary>
+    /// Decoded float view of <see cref="EncodedVector"/>. Each read decodes again;
+    /// assigning re-encodes the value with the raw float32 encoding.
+    /// </summary>
+    public float[] Vector
+    {
+        get => EncodedVector.Decode();
+        set => EncodedVector = RawFloat32Encoding.Instance.Encode(value);
+    }
 }
 
 ///
@@ -47,4 +85,8 @@ public class VectorTextItem : VectorTextItem, IVec
     public VectorTextItem(string text, TMetadata? metadata, float[] vector)
         : base(text, metadata, vector)
     { }
-}
\ No newline at end of file
+
+    /// <summary>
+    /// Creates an item from an already-encoded vector.
+    /// </summary>
+    public VectorTextItem(string text, TMetadata? metadata, IEncodedVector encodedVector)
+        : base(text, metadata, encodedVector)
+    { }
+}
diff --git a/src/Build5Nines.SharpVector/VectorTextItemJsonConverter.cs b/src/Build5Nines.SharpVector/VectorTextItemJsonConverter.cs
new file mode 100644
index 0000000..ac3dc49
--- /dev/null
+++ b/src/Build5Nines.SharpVector/VectorTextItemJsonConverter.cs
@@ -0,0 +1,137 @@
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Build5Nines.SharpVector.VectorEncoding;
+
+namespace Build5Nines.SharpVector;
+
+/// <summary>
+/// JsonConverter factory for <see cref="VectorTextItem{TDocument, TMetadata}"/>.
+/// Handles both the legacy on-disk shape (where vectors are written as a
+/// plain float array under "Vector") and the new shape that carries an
+/// explicit encoding tag plus base64 bytes. Raw-encoded vectors continue to
+/// be written in the legacy shape so previously-saved databases remain
+/// byte-identical.
+/// </summary>
+internal sealed class VectorTextItemJsonConverterFactory : JsonConverterFactory
+{
+    public override bool CanConvert(Type typeToConvert)
+    {
+        if (!typeToConvert.IsGenericType) return false;
+        var def = typeToConvert.GetGenericTypeDefinition();
+        return def == typeof(VectorTextItem<,>);
+    }
+
+    public override JsonConverter? CreateConverter(Type typeToConvert, JsonSerializerOptions options)
+    {
+        // Close the open generic converter over the item's two type arguments.
+        var args = typeToConvert.GetGenericArguments();
+        var converterType = typeof(VectorTextItemJsonConverter<,>).MakeGenericType(args[0], args[1]);
+        return (JsonConverter?)Activator.CreateInstance(converterType);
+    }
+}
+
+/// <summary>
+/// Reads and writes a single vector text item in either the legacy
+/// ("Vector" float array) or the encoded ("EncodingId" / "Dimensions" /
+/// "EncodedBytes") JSON shape.
+/// </summary>
+internal sealed class VectorTextItemJsonConverter<TDocument, TMetadata>
+    : JsonConverter<VectorTextItem<TDocument, TMetadata>>
+{
+    /// <summary>
+    /// Reads one item. Unknown properties are skipped. An object with no vector
+    /// data at all yields a zero-length raw vector.
+    /// </summary>
+    /// <exception cref="JsonException">
+    /// The token stream is not an object, or only some of the three
+    /// encoded-vector properties are present.
+    /// </exception>
+    public override VectorTextItem<TDocument, TMetadata> Read(
+        ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+    {
+        if (reader.TokenType != JsonTokenType.StartObject)
+            throw new JsonException("Expected start of object for VectorTextItem.");
+
+        TDocument? text = default;
+        TMetadata? metadata = default;
+        float[]? legacyVector = null;
+        string? encodingId = null;
+        int? dimensions = null;
+        byte[]? encodedBytes = null;
+
+        while (reader.Read())
+        {
+            if (reader.TokenType == JsonTokenType.EndObject) break;
+            if (reader.TokenType != JsonTokenType.PropertyName)
+                throw new JsonException("Expected property name in VectorTextItem.");
+
+            string? propName = reader.GetString();
+            reader.Read();
+
+            switch (propName)
+            {
+                case "Text":
+                case "text":
+                    text = JsonSerializer.Deserialize<TDocument>(ref reader, options);
+                    break;
+                case "Metadata":
+                case "metadata":
+                    metadata = JsonSerializer.Deserialize<TMetadata>(ref reader, options);
+                    break;
+                case "Vector":
+                case "vector":
+                    legacyVector = JsonSerializer.Deserialize<float[]>(ref reader, options);
+                    break;
+                case "EncodingId":
+                case "encodingId":
+                    encodingId = reader.GetString();
+                    break;
+                case "Dimensions":
+                case "dimensions":
+                    dimensions = reader.GetInt32();
+                    break;
+                case "EncodedBytes":
+                case "encodedBytes":
+                    encodedBytes = reader.GetBytesFromBase64();
+                    break;
+                default:
+                    reader.Skip();
+                    break;
+            }
+        }
+
+        bool hasAnyEncodedField = encodingId is not null || dimensions is not null || encodedBytes is not null;
+
+        IEncodedVector encoded;
+        if (encodingId is not null && encodedBytes is not null && dimensions is not null)
+        {
+            var encoding = VectorEncodingRegistry.Get(encodingId);
+            encoded = encoding.LoadFromBytes(encodedBytes, dimensions.Value);
+        }
+        else if (hasAnyEncodedField)
+        {
+            // A partial encoded payload means a damaged file. Falling through
+            // would silently replace the stored vector with an empty one.
+            throw new JsonException(
+                "VectorTextItem requires EncodingId, Dimensions and EncodedBytes together.");
+        }
+        else if (legacyVector is not null)
+        {
+            encoded = RawFloat32Encoding.Instance.Encode(legacyVector);
+        }
+        else
+        {
+            // Empty/missing vector — preserve null-ish behavior with a zero-length raw encoding.
+            encoded = RawFloat32Encoding.Instance.Encode(Array.Empty<float>());
+        }
+
+        return new VectorTextItem<TDocument, TMetadata>(text!, metadata, encoded);
+    }
+
+    /// <summary>
+    /// Writes one item: the legacy "Vector" array for raw-encoded vectors,
+    /// otherwise the encoding id, dimension count and base64 payload.
+    /// </summary>
+    public override void Write(
+        Utf8JsonWriter writer, VectorTextItem<TDocument, TMetadata> value, JsonSerializerOptions options)
+    {
+        writer.WriteStartObject();
+
+        writer.WritePropertyName("Text");
+        JsonSerializer.Serialize(writer, value.Text, options);
+
+        writer.WritePropertyName("Metadata");
+        JsonSerializer.Serialize(writer, value.Metadata, options);
+
+        if (value.EncodedVector.EncodingId == RawFloat32Encoding.EncodingId)
+        {
+            // Preserve the legacy on-disk shape exactly: a plain float array
+            // under "Vector". This means files written by raw-encoded databases
+            // match what older versions produced, byte for byte.
+            writer.WritePropertyName("Vector");
+            JsonSerializer.Serialize(writer, value.Vector, options);
+        }
+        else
+        {
+            writer.WriteString("EncodingId", value.EncodedVector.EncodingId);
+            writer.WriteNumber("Dimensions", value.EncodedVector.Dimensions);
+            writer.WriteBase64String("EncodedBytes", value.EncodedVector.GetBytes());
+        }
+
+        writer.WriteEndObject();
+    }
+}
diff --git a/src/SharpVectorTest/VectorEncoding/EncodedDatabaseTests.cs b/src/SharpVectorTest/VectorEncoding/EncodedDatabaseTests.cs
new file mode 100644
index 0000000..8f04d4d
--- /dev/null
+++ b/src/SharpVectorTest/VectorEncoding/EncodedDatabaseTests.cs
@@ -0,0 +1,182 @@
+using Build5Nines.SharpVector;
+using Build5Nines.SharpVector.VectorEncoding;
+
+namespace SharpVectorTest.VectorEncoding;
+
+[TestClass]
+public class EncodedDatabaseTests
+{
+    private const string SampleText =
+        "The Lion King is a 1994 Disney animated film about a young lion cub named Simba who is the heir to the throne of an African savanna.";
+
+    [TestMethod]
+    public void DefaultDatabaseUsesRawEncoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase();
+        Assert.AreEqual(RawFloat32Encoding.EncodingId, vdb.VectorEncoding.Id);
+    }
+
+    [TestMethod]
+    public void ConstructorAcceptsEncoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase(Int8ScalarQuantizationEncoding.Instance);
+        Assert.AreEqual(Int8ScalarQuantizationEncoding.EncodingId, vdb.VectorEncoding.Id);
+    }
+
+    [TestMethod]
+    public void Int8Database_StoresEncodedVectorAndStillFindsResult()
+    {
+        var vdb = new BasicMemoryVectorDatabase(Int8ScalarQuantizationEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        var stored = vdb.GetText(1);
+        Assert.AreEqual(Int8ScalarQuantizationEncoding.EncodingId, stored.EncodedVector.EncodingId);
+
+        var results = vdb.Search("Lion King");
+        Assert.IsTrue(results.Texts.Any(t => t.Text.Contains("Lion King")));
+    }
+
+    [TestMethod]
+    public async Task SaveAndLoad_PreservesRawEncoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase();
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        var reloaded = new BasicMemoryVectorDatabase();
+        await reloaded.DeserializeFromBinaryStreamAsync(ms);
+
+        Assert.AreEqual(RawFloat32Encoding.EncodingId, reloaded.VectorEncoding.Id);
+        Assert.AreEqual(SampleText, reloaded.GetText(1).Text);
+
+        var results = reloaded.Search("Lion King");
+        Assert.IsTrue(results.Texts.Any(t => t.Text.Contains("Lion King")));
+    }
+
+    [TestMethod]
+    public async Task SaveAndLoad_PreservesInt8Encoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase(Int8ScalarQuantizationEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        // Construct the reload-target with raw; the file's encoding should win.
+        var reloaded = new BasicMemoryVectorDatabase();
+        await reloaded.DeserializeFromBinaryStreamAsync(ms);
+
+        Assert.AreEqual(Int8ScalarQuantizationEncoding.EncodingId, reloaded.VectorEncoding.Id);
+
+        var stored = reloaded.GetText(1);
+        Assert.AreEqual(Int8ScalarQuantizationEncoding.EncodingId, stored.EncodedVector.EncodingId);
+        Assert.AreEqual(SampleText, stored.Text);
+    }
+
+    [TestMethod]
+    public async Task RawSavedFile_DoesNotContainEncodingIdField()
+    {
+        // To preserve byte compatibility with files written by older versions
+        // of the library, a raw-encoded database must not emit the new
+        // VectorEncodingId property into database.json.
+        var vdb = new BasicMemoryVectorDatabase();
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        // The serialized stream is opened as a zip archive to inspect database.json directly.
+        using var archive = new System.IO.Compression.ZipArchive(ms, System.IO.Compression.ZipArchiveMode.Read);
+        var dbEntry = archive.GetEntry("database.json")!;
+        using var dbStream = dbEntry.Open();
+        using var reader = new StreamReader(dbStream);
+        var json = reader.ReadToEnd();
+
+        StringAssert.Contains(json, "\"ClassType\"");
+        Assert.IsFalse(json.Contains("VectorEncodingId"),
+            $"Raw-encoded database.json must omit VectorEncodingId. Actual: {json}");
+    }
+
+    // A database built with a non-raw encoding must record that encoding's id in database.json.
+    [TestMethod]
+    public async Task Int8SavedFile_RecordsEncodingId()
+    {
+        var vdb = new BasicMemoryVectorDatabase(Int8ScalarQuantizationEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        using var archive = new System.IO.Compression.ZipArchive(ms, System.IO.Compression.ZipArchiveMode.Read);
+        var dbEntry = archive.GetEntry("database.json")!;
+        using var dbStream = dbEntry.Open();
+        using var reader = new StreamReader(dbStream);
+        var json = reader.ReadToEnd();
+
+        StringAssert.Contains(json, "\"VectorEncodingId\":\"" + Int8ScalarQuantizationEncoding.EncodingId + "\"");
+    }
+
+    [TestMethod]
+    public void RaBitQDatabase_StoresEncodedVectorAndStillFindsResult()
+    {
+        var vdb = new BasicMemoryVectorDatabase(RaBitQEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        var stored = vdb.GetText(1);
+        Assert.AreEqual(RaBitQEncoding.EncodingId, stored.EncodedVector.EncodingId);
+
+        var results = vdb.Search("Lion King");
+        Assert.IsTrue(results.Texts.Any(t => t.Text.Contains("Lion King")));
+    }
+
+    [TestMethod]
+    public async Task SaveAndLoad_PreservesRaBitQEncoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase(RaBitQEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        // The reload target is built with the default constructor; the assertion
+        // below expects the encoding recorded in the stream to be adopted.
+        var reloaded = new BasicMemoryVectorDatabase();
+        await reloaded.DeserializeFromBinaryStreamAsync(ms);
+
+        Assert.AreEqual(RaBitQEncoding.EncodingId, reloaded.VectorEncoding.Id);
+        Assert.AreEqual(SampleText, reloaded.GetText(1).Text);
+    }
+
+    [TestMethod]
+    public void TurboQuantDatabase_StoresEncodedVectorAndStillFindsResult()
+    {
+        var vdb = new BasicMemoryVectorDatabase(TurboQuantEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        var stored = vdb.GetText(1);
+        Assert.AreEqual(TurboQuantEncoding.EncodingId, stored.EncodedVector.EncodingId);
+
+        var results = vdb.Search("Lion King");
+        Assert.IsTrue(results.Texts.Any(t => t.Text.Contains("Lion King")));
+    }
+
+    [TestMethod]
+    public async Task SaveAndLoad_PreservesTurboQuantEncoding()
+    {
+        var vdb = new BasicMemoryVectorDatabase(TurboQuantEncoding.Instance);
+        vdb.AddText(SampleText, "meta");
+
+        using var ms = new MemoryStream();
+        await vdb.SerializeToBinaryStreamAsync(ms);
+        ms.Position = 0;
+
+        var reloaded = new BasicMemoryVectorDatabase();
+        await reloaded.DeserializeFromBinaryStreamAsync(ms);
+
+        Assert.AreEqual(TurboQuantEncoding.EncodingId, reloaded.VectorEncoding.Id);
+        Assert.AreEqual(SampleText, reloaded.GetText(1).Text);
+    }
+}
diff --git a/src/SharpVectorTest/VectorEncoding/EncodingTests.cs b/src/SharpVectorTest/VectorEncoding/EncodingTests.cs
new file mode 100644
index 0000000..8c35a30
--- /dev/null
+++ b/src/SharpVectorTest/VectorEncoding/EncodingTests.cs
@@ -0,0 +1,180 @@
+using Build5Nines.SharpVector;
+using Build5Nines.SharpVector.VectorEncoding;
+
+namespace SharpVectorTest.VectorEncoding;
+
+[TestClass]
+public class EncodingTests
+{
+    // Builds a seeded pseudo-random vector with components in [-1, 1).
+    private static float[] SampleVector(int dims, int seed)
+    {
+        var rng = new Random(seed);
+        var v = new float[dims];
+        for (int i = 0; i < dims; i++) v[i] = (float)(rng.NextDouble() * 2.0 - 1.0);
+        return v;
+    }
+
+    [TestMethod]
+    public void RawFloat32_RoundTrip_IsLossless()
+    {
+        var original = SampleVector(128, 1);
+        var encoded = RawFloat32Encoding.Instance.Encode(original);
+        Assert.AreEqual(RawFloat32Encoding.EncodingId, encoded.EncodingId);
+        Assert.AreEqual(128, encoded.Dimensions);
+
+        var decoded = encoded.Decode();
+        CollectionAssert.AreEqual(original, decoded);
+    }
+
+    [TestMethod]
+    public void RawFloat32_BytesRoundTrip_IsLossless()
+    {
+        var original = SampleVector(64, 2);
+        var encoded = RawFloat32Encoding.Instance.Encode(original);
+        var bytes = encoded.GetBytes();
+
+        var rehydrated = RawFloat32Encoding.Instance.LoadFromBytes(bytes, 64);
+        CollectionAssert.AreEqual(original, rehydrated.Decode());
+    }
+
+    [TestMethod]
+    public void Int8Sq_RoundTrip_PreservesCosineSimilarityClosely()
+    {
+        var a = SampleVector(384, 3);
+        var b = SampleVector(384, 4);
+
+        // Baseline: cosine similarity computed on the unquantized vector.
+        var rawEncA = RawFloat32Encoding.Instance.Encode(a);
+        var rawCos = RawFloat32Encoding.Instance.Compare(VectorComparison.CosineSimilarity, b, rawEncA);
+
+        var encA = Int8ScalarQuantizationEncoding.Instance.Encode(a);
+        var cosViaInt8 = Int8ScalarQuantizationEncoding.Instance.Compare(
+            VectorComparison.CosineSimilarity, b, encA);
+
+        // int8-sq introduces small rounding error; for random vectors this is well under 1%.
+        Assert.IsTrue(Math.Abs(rawCos - cosViaInt8) < 0.01f,
+            $"Expected cosine within 0.01 of raw value, got |{rawCos} - {cosViaInt8}| = {Math.Abs(rawCos - cosViaInt8)}");
+    }
+
+    [TestMethod]
+    public void Int8Sq_BytesRoundTrip_RestoresEncodedForm()
+    {
+        var original = SampleVector(256, 5);
+        var encoded = Int8ScalarQuantizationEncoding.Instance.Encode(original);
+        var bytes = encoded.GetBytes();
+
+        var rehydrated = Int8ScalarQuantizationEncoding.Instance.LoadFromBytes(bytes, 256);
+
+        // Compare cosine sim against an arbitrary query: should be identical
+        // since the encoded form is fully recovered.
+        var query = SampleVector(256, 6);
+        float a = Int8ScalarQuantizationEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encoded);
+        float b = Int8ScalarQuantizationEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, rehydrated);
+        Assert.AreEqual(a, b);
+    }
+
+    // Every encoding registered by the registry's static constructor must
+    // resolve to its singleton instance.
+    [TestMethod]
+    public void Registry_ResolvesBuiltinEncodings()
+    {
+        Assert.AreSame(RawFloat32Encoding.Instance, VectorEncodingRegistry.Get(RawFloat32Encoding.EncodingId));
+        Assert.AreSame(Int8ScalarQuantizationEncoding.Instance, VectorEncodingRegistry.Get(Int8ScalarQuantizationEncoding.EncodingId));
+        Assert.AreSame(RaBitQEncoding.Instance, VectorEncodingRegistry.Get(RaBitQEncoding.EncodingId));
+        Assert.AreSame(TurboQuantEncoding.Instance, VectorEncodingRegistry.Get(TurboQuantEncoding.EncodingId));
+    }
+
+    [TestMethod]
+    public void Registry_UnknownIdThrows()
+    {
+        // Get throws KeyNotFoundException for an id that was never registered.
+        Assert.ThrowsException<KeyNotFoundException>(() => VectorEncodingRegistry.Get("no-such-encoding"));
+    }
+
+    [TestMethod]
+    public void RaBitQ_BytesRoundTrip_RestoresEncodedForm()
+    {
+        var original = SampleVector(256, 7);
+        var encoded = RaBitQEncoding.Instance.Encode(original);
+        var bytes = encoded.GetBytes();
+
+        var rehydrated = RaBitQEncoding.Instance.LoadFromBytes(bytes, 256);
+
+        var query = SampleVector(256, 8);
+        float a = RaBitQEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encoded);
+        float b = RaBitQEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, rehydrated);
+        Assert.AreEqual(a, b);
+    }
+
+    [TestMethod]
+    public void RaBitQ_RankingMatchesRawForRandomVectors()
+    {
+        // RaBitQ is a coarse approximation, so don't assert numeric closeness;
+        // instead assert it can still rank a clearly-similar pair above a
+        // clearly-dissimilar pair.
+        var query = SampleVector(512, 10);
+        var similar = (float[])query.Clone();
+        for (int i = 0; i < similar.Length; i++) similar[i] += 0.05f * (i % 3 - 1);
+        var different = SampleVector(512, 11);
+
+        var encSimilar = RaBitQEncoding.Instance.Encode(similar);
+        var encDifferent = RaBitQEncoding.Instance.Encode(different);
+
+        float simSimilar = RaBitQEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encSimilar);
+        float simDifferent = RaBitQEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encDifferent);
+
+        Assert.IsTrue(simSimilar > simDifferent,
+            $"RaBitQ similarity ranking incorrect: similar={simSimilar}, different={simDifferent}");
+    }
+
+    [TestMethod]
+    public void RaBitQ_ProducesExpectedStorageSize()
+    {
+        // 8 bytes of scalar correction + ceil(D/8) bytes of sign bits.
+        var encoded = RaBitQEncoding.Instance.Encode(SampleVector(384, 12));
+        Assert.AreEqual(8 + 48, encoded.GetBytes().Length);
+    }
+
+    [TestMethod]
+    public void TurboQuant_RoundTrip_PreservesCosineSimilarityClosely()
+    {
+        var a = SampleVector(384, 13);
+        var b = SampleVector(384, 14);
+
+        var rawEncA = RawFloat32Encoding.Instance.Encode(a);
+        var rawCos = RawFloat32Encoding.Instance.Compare(VectorComparison.CosineSimilarity, b, rawEncA);
+
+        var encA = TurboQuantEncoding.Instance.Encode(a);
+        var cosViaTurbo = TurboQuantEncoding.Instance.Compare(VectorComparison.CosineSimilarity, b, encA);
+
+        // 4-bit SQ is coarser than int8; allow up to 3% deviation for random vectors.
+        Assert.IsTrue(Math.Abs(rawCos - cosViaTurbo) < 0.03f,
+            $"Expected cosine within 0.03 of raw value, got |{rawCos} - {cosViaTurbo}| = {Math.Abs(rawCos - cosViaTurbo)}");
+    }
+
+    [TestMethod]
+    public void TurboQuant_BytesRoundTrip_RestoresEncodedForm()
+    {
+        var original = SampleVector(256, 15);
+        var encoded = TurboQuantEncoding.Instance.Encode(original);
+        var bytes = encoded.GetBytes();
+
+        var rehydrated = TurboQuantEncoding.Instance.LoadFromBytes(bytes, 256);
+
+        var query = SampleVector(256, 16);
+        float a = TurboQuantEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encoded);
+        float b = TurboQuantEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, rehydrated);
+        Assert.AreEqual(a, b);
+    }
+
+    [TestMethod]
+    public void TurboQuant_HandlesOddDimensions()
+    {
+        // Odd dimension count exercises the nibble-packing tail case.
+        var original = SampleVector(127, 17);
+        var encoded = TurboQuantEncoding.Instance.Encode(original);
+        var bytes = encoded.GetBytes();
+        // 4 bytes scale + ceil(127/2) = 64 bytes payload
+        Assert.AreEqual(4 + 64, bytes.Length);
+
+        var rehydrated = TurboQuantEncoding.Instance.LoadFromBytes(bytes, 127);
+        var query = SampleVector(127, 18);
+        float a = TurboQuantEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, encoded);
+        float b = TurboQuantEncoding.Instance.Compare(VectorComparison.CosineSimilarity, query, rehydrated);
+        Assert.AreEqual(a, b);
+    }
+}