diff --git a/DuckDB.NET.Data/DataChunk/Reader/EnumVectorDataReader.cs b/DuckDB.NET.Data/DataChunk/Reader/EnumVectorDataReader.cs index e10f87f..79d12d4 100644 --- a/DuckDB.NET.Data/DataChunk/Reader/EnumVectorDataReader.cs +++ b/DuckDB.NET.Data/DataChunk/Reader/EnumVectorDataReader.cs @@ -46,14 +46,14 @@ T ToEnumOrString(TSource enumValue) where TSource: IBinaryNumber(ref enumValue); } } @@ -72,12 +72,12 @@ internal override object GetValue(ulong offset, Type targetType) if (targetType == typeof(string)) { - if (!cachedNames.TryGetValue(enumValue, out var name)) - { - cachedNames[enumValue] = name = NativeMethods.LogicalType.DuckDBEnumDictionaryValue(logicalType, enumValue); - } + return GetEnumName(enumValue); + } - return name; + if (targetType.IsEnum) + { + return ConvertToTargetEnum(enumValue, targetType); } return Enum.ToObject(targetType, enumValue); @@ -91,4 +91,25 @@ public override void Dispose() logicalType.Dispose(); base.Dispose(); } + + private string GetEnumName(long enumValue) + { + if (!cachedNames.TryGetValue(enumValue, out var name)) + { + cachedNames[enumValue] = name = NativeMethods.LogicalType.DuckDBEnumDictionaryValue(logicalType, enumValue); + } + + return name; + } + + private object ConvertToTargetEnum(long enumValue, Type targetType) + { + var enumName = GetEnumName(enumValue); + if (Enum.TryParse(targetType, enumName, true, out var parsedEnum)) + { + return parsedEnum; + } + + throw new InvalidCastException($"Cannot convert DuckDB enum value \"{enumName}\" to {targetType.Name}."); + } } \ No newline at end of file diff --git a/DuckDB.NET.Data/DataChunk/Writer/EnumVectorDataWriter.cs b/DuckDB.NET.Data/DataChunk/Writer/EnumVectorDataWriter.cs index 787b5d4..45544a8 100644 --- a/DuckDB.NET.Data/DataChunk/Writer/EnumVectorDataWriter.cs +++ b/DuckDB.NET.Data/DataChunk/Writer/EnumVectorDataWriter.cs @@ -6,29 +6,14 @@ internal sealed unsafe class EnumVectorDataWriter(IntPtr vector, void* vectorDat private readonly uint enumDictionarySize = NativeMethods.LogicalType.DuckDBEnumDictionarySize(logicalType); - private readonly Dictionary enumValues = []; + private readonly Dictionary enumValues = new(StringComparer.OrdinalIgnoreCase); internal override bool AppendString(string value, ulong rowIndex) { - if (enumValues.Count == 0) - { - for (uint index = 0; index < enumDictionarySize; index++) - { - var enumValueName = NativeMethods.LogicalType.DuckDBEnumDictionaryValue(logicalType, index); - enumValues.Add(enumValueName, index); - } - } - + EnsureEnumValuesInitialized(); if (enumValues.TryGetValue(value, out var enumValue)) { - // The following casts to byte and ushort are safe because we ensure in the constructor that the value enumDictionarySize is not too high. - return enumType switch - { - DuckDBType.UnsignedTinyInt => AppendValueInternal((byte)enumValue, rowIndex), - DuckDBType.UnsignedSmallInt => AppendValueInternal((ushort)enumValue, rowIndex), - DuckDBType.UnsignedInteger => AppendValueInternal(enumValue, rowIndex), - _ => throw new InvalidOperationException($"Failed to write Enum column because the internal enum type must be utinyint, usmallint, or uinteger."), - }; + return AppendEnumDictionaryIndex(enumValue, rowIndex); } throw new InvalidOperationException($"Failed to write Enum column because the value \"{value}\" is not valid."); @@ -36,36 +21,49 @@ internal override bool AppendString(string value, ulong rowIndex) internal override bool AppendEnum(TEnum value, ulong rowIndex) { - var enumValue = ConvertEnumValueToUInt64(value); - if (enumValue < enumDictionarySize) + var enumValueType = value.GetType(); + if (enumValueType.IsDefined(typeof(FlagsAttribute), false)) { - // The following casts to byte, ushort and uint are safe because we ensure in the constructor that the value enumDictionarySize is not too high. - return enumType switch + throw new InvalidOperationException("Failed to write Enum column because [Flags] enums are not supported."); + } + + var enumName = Enum.GetName(enumValueType, value); + if (enumName is not null) + { + EnsureEnumValuesInitialized(); + if (enumValues.TryGetValue(enumName, out var enumValue)) { - DuckDBType.UnsignedTinyInt => AppendValueInternal((byte)enumValue, rowIndex), - DuckDBType.UnsignedSmallInt => AppendValueInternal((ushort)enumValue, rowIndex), - DuckDBType.UnsignedInteger => AppendValueInternal((uint)enumValue, rowIndex), - _ => throw new InvalidOperationException($"Failed to write Enum column because the internal enum type must be utinyint, usmallint, or uinteger."), - }; + return AppendEnumDictionaryIndex(enumValue, rowIndex); + } } - throw new InvalidOperationException($"Failed to write Enum column because the value is outside the range (0-{enumDictionarySize - 1})."); + throw new InvalidOperationException($"Failed to write Enum column because the value \"{value}\" is not valid."); } - private static ulong ConvertEnumValueToUInt64(TEnum value) where TEnum : Enum + private bool AppendEnumDictionaryIndex(ulong dictionaryIndex, ulong rowIndex) { - return value.GetTypeCode() switch + // The following casts to byte and ushort are safe because we ensure in the constructor that the enumDictionarySize is not too high. + return enumType switch { - TypeCode.SByte => (ulong)Convert.ToSByte(value), - TypeCode.Byte => Convert.ToByte(value), - TypeCode.Int16 => (ulong)Convert.ToInt16(value), - TypeCode.UInt16 => Convert.ToUInt16(value), - TypeCode.Int32 => (ulong)Convert.ToInt32(value), - TypeCode.UInt32 => Convert.ToUInt32(value), - TypeCode.Int64 => (ulong)Convert.ToInt64(value), - TypeCode.UInt64 => Convert.ToUInt64(value), - _ => throw new InvalidOperationException($"Failed to convert the enum value {value} to ulong."), + DuckDBType.UnsignedTinyInt => AppendValueInternal((byte)dictionaryIndex, rowIndex), + DuckDBType.UnsignedSmallInt => AppendValueInternal((ushort)dictionaryIndex, rowIndex), + DuckDBType.UnsignedInteger => AppendValueInternal((uint)dictionaryIndex, rowIndex), + _ => throw new InvalidOperationException("Failed to write Enum column because the internal enum type must be utinyint, usmallint, or uinteger."), }; } + private void EnsureEnumValuesInitialized() + { + if (enumValues.Count != 0) + { + return; + } + + for (uint index = 0; index < enumDictionarySize; index++) + { + var enumValueName = NativeMethods.LogicalType.DuckDBEnumDictionaryValue(logicalType, index); + enumValues.Add(enumValueName, index); + } + } + } diff --git a/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs b/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs index e81d861..621cc31 100644 --- a/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs +++ b/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs @@ -411,6 +411,54 @@ public void EnumValues() result.Item9.Should().Be(TestEnum3.Test6699); } + [Fact] + public void EnumValuesWithNonConsecutiveUnderlyingValues() + { + var enumLabelsSql = string.Join(", ", Enum.GetNames().Select(name => $"'{name}'")); + Command.CommandText = $"CREATE TYPE non_consecutive_test_enum AS ENUM ({enumLabelsSql});"; + Command.ExecuteNonQuery(); + + Command.CommandText = "CREATE TABLE managedAppenderNonConsecutiveEnum(a non_consecutive_test_enum, b non_consecutive_test_enum, c non_consecutive_test_enum);"; + Command.ExecuteNonQuery(); + + using (var appender = Connection.CreateAppender("managedAppenderNonConsecutiveEnum")) + { + appender + .CreateRow() + .AppendValue(NonConsecutiveTestEnum.Happy) + .AppendValue(NonConsecutiveTestEnum.Sad) + .AppendValue(NonConsecutiveTestEnum.Neutral) + .EndRow(); + } + + Command.CommandText = "SELECT a, b, c FROM managedAppenderNonConsecutiveEnum"; + using var reader = Command.ExecuteReader(); + reader.Read(); + reader.GetFieldValue(0).Should().Be(NonConsecutiveTestEnum.Happy); + reader.GetFieldValue(1).Should().Be(nameof(NonConsecutiveTestEnum.Sad)); + reader.GetFieldValue(2).Should().Be(NonConsecutiveTestEnum.Neutral); + } + + [Fact] + public void FlagsEnumValuesThrowException() + { + var enumLabelsSql = string.Join(", ", Enum.GetNames().Select(name => $"'{name}'")); + Command.CommandText = $"CREATE TYPE flags_test_enum AS ENUM ({enumLabelsSql});"; + Command.ExecuteNonQuery(); + + Command.CommandText = "CREATE TABLE managedAppenderFlagsEnum(a flags_test_enum);"; + Command.ExecuteNonQuery(); + + Connection.Invoking(dbConnection => + { + using var appender = dbConnection.CreateAppender("managedAppenderFlagsEnum"); + appender + .CreateRow() + .AppendValue(FlagsTestEnum.Happy) + .EndRow(); + }).Should().Throw().Where(exception => exception.Message.Contains("Flags")); + } + [Fact] public void IncompleteRowThrowsException() { @@ -851,4 +899,19 @@ private enum EnumNotValidValueTestEnum { NotValid = 12345, } + + private enum NonConsecutiveTestEnum : byte + { + Happy = 1, + Sad = 2, + Neutral = 4, + } + + [Flags] + private enum FlagsTestEnum : byte + { + Happy = 1, + Sad = 2, + Neutral = 4, + } } \ No newline at end of file