diff --git a/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs b/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs index 122375c11d..1ff714b3dd 100644 --- a/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs +++ b/csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs @@ -145,6 +145,73 @@ public virtual void Dispose() /// /// The index in the array to get the value from. /// - public abstract object GetValue(IArrowArray arrowArray, Field field, int index); + public virtual object GetValue(IArrowArray arrowArray, Field field, int index) + { + if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); + if (field == null) throw new ArgumentNullException(nameof(field)); + if (index < 0) throw new ArgumentOutOfRangeException(nameof(index)); + + switch (arrowArray) + { + case BooleanArray booleanArray: + return booleanArray.GetValue(index); + case Date32Array date32Array: + return date32Array.GetDateTime(index); + case Date64Array date64Array: + return date64Array.GetDateTime(index); + case Decimal128Array decimal128Array: + return decimal128Array.GetSqlDecimal(index); + case Decimal256Array decimal256Array: + return decimal256Array.GetString(index); + case DoubleArray doubleArray: + return doubleArray.GetValue(index); + case FloatArray floatArray: + return floatArray.GetValue(index); +#if NET5_0_OR_GREATER + case PrimitiveArray halfFloatArray: + return halfFloatArray.GetValue(index); +#endif + case Int8Array int8Array: + return int8Array.GetValue(index); + case Int16Array int16Array: + return int16Array.GetValue(index); + case Int32Array int32Array: + return int32Array.GetValue(index); + case Int64Array int64Array: + return int64Array.GetValue(index); + case StringArray stringArray: + return stringArray.GetString(index); + case Time32Array time32Array: + return time32Array.GetValue(index); + case Time64Array time64Array: + return time64Array.GetValue(index); + case TimestampArray timestampArray: + DateTimeOffset dateTimeOffset = timestampArray.GetTimestamp(index).Value; + return dateTimeOffset; + case UInt8Array uInt8Array: + return uInt8Array.GetValue(index); + case UInt16Array uInt16Array: + return uInt16Array.GetValue(index); + case UInt32Array uInt32Array: + return uInt32Array.GetValue(index); + case UInt64Array uInt64Array: + return uInt64Array.GetValue(index); + + case BinaryArray binaryArray: + if (!binaryArray.IsNull(index)) + return binaryArray.GetBytes(index).ToArray(); + + return null; + + // not covered: + // -- struct array + // -- dictionary array + // -- fixed size binary + // -- list array + // -- union array + } + + return null; + } } } diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs index 914a7f224d..c621eb6ecd 100644 --- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs +++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs @@ -17,10 +17,10 @@ using System; using System.Collections.Generic; +using System.Data; using System.IO; using System.Linq; using System.Runtime.InteropServices; -using System.Text.RegularExpressions; using Apache.Arrow.C; using Apache.Arrow.Ipc; @@ -328,158 +328,6 @@ public override unsafe UpdateResult ExecuteUpdate() return new UpdateResult(rows); } } - - public override object GetValue(IArrowArray arrowArray, Field field, int index) - { - if (arrowArray is BooleanArray) - { - return ((BooleanArray)arrowArray).GetValue(index); - } - else if (arrowArray is Date32Array) - { - Date32Array date32Array = (Date32Array)arrowArray; - - return date32Array.GetDateTime(index); - } - else if (arrowArray is Date64Array) - { - Date64Array date64Array = (Date64Array)arrowArray; - - return date64Array.GetDateTime(index); - } - else if (arrowArray is Decimal128Array) - { - try - { - // the value may be decimal.max - // then Arrow throws an exception - // no good way to check prior to - return ((Decimal128Array)arrowArray).GetValue(index); - } - catch (OverflowException oex) - { - return ParseDecimalValueFromOverflowException(oex); - } - } - else if (arrowArray is Decimal256Array) - { - try - { - return ((Decimal256Array)arrowArray).GetValue(index); - } - catch (OverflowException oex) - { - return ParseDecimalValueFromOverflowException(oex); - } - } - else if (arrowArray is DoubleArray) - { - return ((DoubleArray)arrowArray).Values[index]; - } - else if (arrowArray is FloatArray) - { - return ((FloatArray)arrowArray).GetValue(index); - } -#if NET5_0_OR_GREATER - else if (arrowArray is PrimitiveArray) - { - // TODO: HalfFloatArray not present in current library - - return ((PrimitiveArray)arrowArray).GetValue(index); - } -#endif - else if (arrowArray is Int8Array) - { - Int8Array array = (Int8Array)arrowArray; - return array.GetValue(index); - } - else if (arrowArray is Int16Array) - { - return ((Int16Array)arrowArray).Values[index]; - } - else if (arrowArray is Int32Array) - { - return ((Int32Array)arrowArray).Values[index]; - } - else if (arrowArray is Int64Array) - { - Int64Array array = (Int64Array)arrowArray; - - return array.GetValue(index); - } - else if (arrowArray is StringArray) - { - return ((StringArray)arrowArray).GetString(index); - } - else if (arrowArray is Time32Array) - { - return ((Time32Array)arrowArray).GetValue(index); - } - else if (arrowArray is Time64Array) - { - return ((Time64Array)arrowArray).GetValue(index); - } - else if (arrowArray is TimestampArray) - { - TimestampArray timestampArray = (TimestampArray)arrowArray; - DateTimeOffset dateTimeOffset = timestampArray.GetTimestamp(index).Value; - return dateTimeOffset; - } - else if (arrowArray is UInt8Array) - { - return ((UInt8Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt16Array) - { - return ((UInt16Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt32Array) - { - return ((UInt32Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt64Array) - { - return ((UInt64Array)arrowArray).GetValue(index); - } - else if (arrowArray is BinaryArray) - { - ReadOnlySpan bytes = ((BinaryArray)arrowArray).GetBytes(index); - - if (bytes != null) - return bytes.ToArray(); - } - - // not covered: - // -- struct array - // -- dictionary array - // -- fixed size binary - // -- list array - // -- union array - - return null; - } - - private string ParseDecimalValueFromOverflowException(OverflowException oex) - { - if (oex == null) - throw new ArgumentNullException(nameof(oex)); - - // any decimal value, positive or negative, with or without a decimal in place - Regex regex = new Regex(" -?\\d*\\.?\\d* "); - - var matches = regex.Matches(oex.Message); - - foreach (Match match in matches) - { - string value = match.Value; - - if (!string.IsNullOrEmpty(value)) - return value; - } - - throw oex; - } - } /// diff --git a/csharp/src/Client/AdbcCommand.cs b/csharp/src/Client/AdbcCommand.cs index 585d50454f..2fcf6c302d 100644 --- a/csharp/src/Client/AdbcCommand.cs +++ b/csharp/src/Client/AdbcCommand.cs @@ -50,6 +50,7 @@ public AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection) : this.adbcStatement = adbcStatement; this.DbConnection = adbcConnection; + this.DecimalBehavior = adbcConnection.DecimalBehavior; } /// @@ -69,6 +70,7 @@ public AdbcCommand(string query, AdbcConnection adbcConnection) : base() this.CommandText = query; this.DbConnection = adbcConnection; + this.DecimalBehavior = adbcConnection.DecimalBehavior; } /// @@ -77,6 +79,8 @@ public AdbcCommand(string query, AdbcConnection adbcConnection) : base() /// public AdbcStatement AdbcStatement => this.adbcStatement; + public DecimalBehavior DecimalBehavior { get; set; } + public override string CommandText { get => this.adbcStatement.SqlQuery; @@ -170,7 +174,7 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) case CommandBehavior.SchemaOnly: // The schema is not known until a read happens case CommandBehavior.Default: QueryResult result = this.ExecuteQuery(); - return new AdbcDataReader(this, result); + return new AdbcDataReader(this, result, this.DecimalBehavior); default: throw new InvalidOperationException($"{behavior} is not supported with this provider"); diff --git a/csharp/src/Client/AdbcConnection.cs b/csharp/src/Client/AdbcConnection.cs index 6476f1b9ab..598bc66f4c 100644 --- a/csharp/src/Client/AdbcConnection.cs +++ b/csharp/src/Client/AdbcConnection.cs @@ -41,6 +41,7 @@ public sealed class AdbcConnection : DbConnection public AdbcConnection() { this.AdbcDriver = null; + this.DecimalBehavior = DecimalBehavior.UseSqlDecimal; this.adbcConnectionParameters = new Dictionary(); this.adbcConnectionOptions = new Dictionary(); } @@ -55,7 +56,7 @@ public AdbcConnection(string connectionString) : this() } /// - /// Overloaded. Intializes an . + /// Overloaded. Initializes an . /// /// /// The to use for connecting. This value @@ -67,7 +68,7 @@ public AdbcConnection(AdbcDriver adbcDriver) : this() } /// - /// Overloaded. Intializes an . + /// Overloaded. Initializes an . /// /// /// The to use for connecting. This value @@ -120,6 +121,11 @@ internal AdbcStatement AdbcStatement public override string ConnectionString { get => GetConnectionString(); set => SetConnectionProperties(value); } + /// + /// Gets or sets the behavior of decimals. + /// + public DecimalBehavior DecimalBehavior { get; set; } + protected override DbCommand CreateDbCommand() { EnsureConnectionOpen(); @@ -220,9 +226,6 @@ public override DataTable GetSchema() return GetSchema(null); } - //GetSchema("TABLES") - //GetSchema("VIEWS") - public override DataTable GetSchema(string collectionName) { return GetSchema(collectionName, null); @@ -231,7 +234,7 @@ public override DataTable GetSchema(string collectionName) public override DataTable GetSchema(string collectionName, string[] restrictionValues) { Schema arrowSchema = this.adbcConnectionInternal.GetTableSchema("", "", ""); - return SchemaConverter.ConvertArrowSchema(arrowSchema, this.AdbcStatement); + return SchemaConverter.ConvertArrowSchema(arrowSchema, this.AdbcStatement, this.DecimalBehavior); } #region NOT_IMPLEMENTED diff --git a/csharp/src/Client/AdbcDataReader.cs b/csharp/src/Client/AdbcDataReader.cs index d5772da579..6e9a7ae49c 100644 --- a/csharp/src/Client/AdbcDataReader.cs +++ b/csharp/src/Client/AdbcDataReader.cs @@ -21,6 +21,7 @@ using System.Collections.ObjectModel; using System.Data; using System.Data.Common; +using System.Data.SqlTypes; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; @@ -42,7 +43,7 @@ public sealed class AdbcDataReader : DbDataReader, IDbColumnSchemaGenerator private bool isClosed; private int recordsEffected = -1; - internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult) + internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult, DecimalBehavior decimalBehavior) { if (adbcCommand == null) throw new ArgumentNullException(nameof(adbcCommand)); @@ -58,6 +59,7 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult) throw new ArgumentException("A Schema must be set for the AdbcQueryResult.Stream property"); this.isClosed = false; + this.DecimalBehavior = decimalBehavior; } public override object this[int ordinal] => GetValue(ordinal); @@ -77,6 +79,8 @@ internal AdbcDataReader(AdbcCommand adbcCommand, QueryResult adbcQueryResult) /// public Schema ArrowSchema => this.schema; + public DecimalBehavior DecimalBehavior { get; set; } + public override int RecordsAffected => this.recordsEffected; /// @@ -128,12 +132,24 @@ public override DateTime GetDateTime(int ordinal) public override decimal GetDecimal(int ordinal) { - return (decimal) GetValue(ordinal); + return Convert.ToDecimal(GetValue(ordinal)); + } + + public SqlDecimal GetSqlDecimal(int ordinal) + { + if (this.DecimalBehavior == DecimalBehavior.UseSqlDecimal) + { + return (SqlDecimal)GetValue(ordinal); + } + else + { + throw new InvalidOperationException("Cannot convert to SqlDecimal if DecimalBehavior.UseSqlDecimal is not configured"); + } } public override double GetDouble(int ordinal) { - return (double) GetValue(ordinal); + return Convert.ToDouble(GetValue(ordinal)); } public override IEnumerator GetEnumerator() @@ -196,7 +212,31 @@ public override string GetString(int ordinal) public override object GetValue(int ordinal) { - return GetValue(this.recordBatch?.Column(ordinal), ordinal); + object value = GetValue(this.recordBatch?.Column(ordinal), ordinal); + + if (value == null) + return null; + + if(value is SqlDecimal dValue) + { + if (this.DecimalBehavior == DecimalBehavior.UseSqlDecimal) + { + return dValue; + } + else + { + try + { + return dValue.Value; + } + catch(OverflowException) + { + return dValue.ToString(); + } + } + } + + return value; } public override int GetValues(object[] values) @@ -248,7 +288,7 @@ public override DataTable GetSchemaTable() { if (this.schema != null) { - return SchemaConverter.ConvertArrowSchema(this.schema, this.adbcCommand.AdbcStatement); + return SchemaConverter.ConvertArrowSchema(this.schema, this.adbcCommand.AdbcStatement, this.DecimalBehavior); } else { @@ -275,7 +315,7 @@ public ReadOnlyCollection GetAdbcColumnSchema() foreach (Field f in this.schema.FieldsList) { - Type t = SchemaConverter.ConvertArrowType(f); + Type t = SchemaConverter.ConvertArrowType(f, this.DecimalBehavior); if(f.HasMetadata && f.Metadata.ContainsKey("precision") && diff --git a/csharp/src/Client/DecimalBehavior.cs b/csharp/src/Client/DecimalBehavior.cs new file mode 100644 index 0000000000..c504b7701c --- /dev/null +++ b/csharp/src/Client/DecimalBehavior.cs @@ -0,0 +1,43 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Data.SqlTypes; + +namespace Apache.Arrow.Adbc.Client +{ + /// + /// Some callers may choose to leverage the default + /// values to get a full range of precision + /// and scale support, while others may opt to restrict + /// use to .NET's range + /// + public enum DecimalBehavior + { + /// + /// Use + /// + UseSqlDecimal, + + /// + /// Use + /// and treat + /// as string values + /// + OverflowDecimalAsString + } +} diff --git a/csharp/src/Client/SchemaConverter.cs b/csharp/src/Client/SchemaConverter.cs index 7e38945f12..7f6d171265 100644 --- a/csharp/src/Client/SchemaConverter.cs +++ b/csharp/src/Client/SchemaConverter.cs @@ -18,6 +18,7 @@ using System; using System.Data; using System.Data.Common; +using System.Data.SqlTypes; using Apache.Arrow.Types; namespace Apache.Arrow.Adbc.Client @@ -31,7 +32,7 @@ internal class SchemaConverter /// The Arrow schema /// The AdbcStatement to use /// - public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStatement) + public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStatement, DecimalBehavior decimalBehavior) { if(schema == null) throw new ArgumentNullException(nameof(schema)); @@ -59,7 +60,7 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat row[SchemaTableColumn.ColumnOrdinal] = columnOrdinal; row[SchemaTableColumn.AllowDBNull] = f.IsNullable; row[SchemaTableColumn.ProviderType] = f.DataType; - Type t = ConvertArrowType(f); + Type t = ConvertArrowType(f, decimalBehavior); row[SchemaTableColumn.DataType] = t; @@ -86,11 +87,11 @@ public static DataTable ConvertArrowSchema(Schema schema, AdbcStatement adbcStat } /// - /// Convert types for Snowflake only + /// Convert types /// /// /// - public static Type ConvertArrowType(Field f) + public static Type ConvertArrowType(Field f, DecimalBehavior decimalBehavior) { switch (f.DataType.TypeId) { @@ -101,8 +102,13 @@ public static Type ConvertArrowType(Field f) return typeof(bool); case ArrowTypeId.Decimal128: + if(decimalBehavior == DecimalBehavior.UseSqlDecimal) + return typeof(SqlDecimal); + else + return typeof(decimal); + case ArrowTypeId.Decimal256: - return typeof(decimal); + return typeof(string); case ArrowTypeId.Time32: case ArrowTypeId.Time64: diff --git a/csharp/src/Client/readme.md b/csharp/src/Client/readme.md index 91e33ca96d..2c4a51aca3 100644 --- a/csharp/src/Client/readme.md +++ b/csharp/src/Client/readme.md @@ -23,20 +23,20 @@ The Client library provides an ADO.NET client over the the top of results from t ## Library Design The Client is designed to work with any driver that inherits from [AdbcDriver](https://github.com/apache/arrow-adbc/blob/main/csharp/src/Apache.Arrow.Adbc/AdbcDriver.cs), whether they are written in a .NET language or a C-compatible language that can be loaded via Interop. -The driver is injected at runtime during the creation of the `Client.AdbcConnection`, seen here: +The driver is injected at runtime during the creation of the `Adbc.Client.AdbcConnection`, seen here: ![Dependency Injection Model](/docs/DependencyInjection.png "Dependency Injection Model") This enables the client to work with multiple ADBC drivers in the same fashion. When a new client AdbcConnection is created, the driver is just passed in as part of the constructor, like: ``` -new Client.AdbcConnection() +new Adbc.Client.AdbcConnection() { new DriverA(), ... } -new Client.AdbcConnection() +new Adbc.Client.AdbcConnection() { new DriverB(), ... diff --git a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs index 64f6d09a2f..bcb0c97464 100644 --- a/csharp/src/Drivers/BigQuery/BigQueryStatement.cs +++ b/csharp/src/Drivers/BigQuery/BigQueryStatement.cs @@ -20,7 +20,6 @@ using System.IO; using System.Linq; using System.Text.Json; -using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Ipc; @@ -78,63 +77,6 @@ public override UpdateResult ExecuteUpdate() return new UpdateResult(updatedRows); } - public override object GetValue(IArrowArray arrowArray, Field field, int index) - { - if(arrowArray == null) throw new ArgumentNullException(nameof(arrowArray)); - if (field == null) throw new ArgumentNullException(nameof(field)); - if (index < 0) throw new ArgumentOutOfRangeException(nameof(index)); - - try - { - switch (arrowArray) - { - case Int64Array int64Array: - return int64Array.GetValue(index); - case DoubleArray doubleArray: - return doubleArray.GetValue(index); - case Decimal128Array decimal128Array: - return decimal128Array.GetValue(index); - case Decimal256Array decimal256Array: - return decimal256Array.GetValue(index); - case BooleanArray booleanArray: - return booleanArray.GetValue(index); - case StringArray stringArray: - return stringArray.GetString(index); - case BinaryArray binaryArray: - - if(!binaryArray.IsNull(index)) - return binaryArray.GetBytes(index).ToArray(); - - return null; - - case Date32Array date32Array: - return date32Array.GetDateTime(index); - case Date64Array date64Array: - return date64Array.GetDateTime(index); - case Time64Array time64Array: - return time64Array.GetValue(index); - case TimestampArray timestampArray: - DateTimeOffset? dateTimeOffset = timestampArray.GetTimestamp(index); - - if(dateTimeOffset != null) - return dateTimeOffset.Value; - - return null; - case StructArray structArray: - return SerializeToJson(structArray, index); - // maybe not be needed? - case ListArray listArray: - return listArray.GetSlicedValues(index); - } - } - catch (OverflowException oex) - { - return ParseDecimalValueFromOverflowException(oex); - } - - return null; - } - private Schema TranslateSchema(TableSchema schema) { return new Schema(schema.Fields.Select(TranslateField), null); @@ -238,30 +180,6 @@ private string SerializeToJson(StructArray structArray, int index) return JsonSerializer.Serialize(jsonDictionary); } - private string ParseDecimalValueFromOverflowException(OverflowException oex) - { - if (oex == null) - throw new ArgumentNullException(nameof(oex)); - - // any decimal value, positive or negative, with or without a decimal in place - Regex regex = new Regex(" -?\\d*\\.?\\d* "); - - MatchCollection matches = regex.Matches(oex.Message); - - // need the second value from the message - if (matches.Count == 2) - { - string value = matches[1].Value; - - if (!string.IsNullOrEmpty(value)) - { - return value.Trim(); - } - } - - throw oex; - } - class MultiArrowReader : IArrowArrayStream { readonly Schema schema; diff --git a/csharp/src/Drivers/BigQuery/readme.md b/csharp/src/Drivers/BigQuery/readme.md index 41ac68646f..310cf32dae 100644 --- a/csharp/src/Drivers/BigQuery/readme.md +++ b/csharp/src/Drivers/BigQuery/readme.md @@ -67,7 +67,7 @@ The following table depicts how the BigQuery ADBC driver converts a BigQuery typ | BigQuery Type | Arrow Type | C# Type |----------|:-------------:| -| BIGNUMERIC | Decimal256 | decimal / string* +| BIGNUMERIC | Decimal256 | string | BOOL | Boolean | bool | BYTES | Binary | byte[] | DATE | Date64 | DateTime @@ -75,14 +75,12 @@ The following table depicts how the BigQuery ADBC driver converts a BigQuery typ | FLOAT64 | Double | double | GEOGRAPHY | String | string | INT64 | Int64 | long -| NUMERIC | Decimal128 | decimal / string* +| NUMERIC | Decimal128 | SqlDecimal | STRING | String | string | STRUCT | String+ | string | TIME |Time64 | long | TIMESTAMP | Timestamp | DateTimeOffset -*An attempt is made to parse the original value as a `decimal` in C#. If that fails, the driver attempts to parse the overflow exception and return the original value. - +A JSON string See [Arrow Schema Details](https://cloud.google.com/bigquery/docs/reference/storage/#arrow_schema_details) for how BigQuery handles Arrow types. diff --git a/csharp/src/Drivers/FlightSql/FlightSqlStatement.cs b/csharp/src/Drivers/FlightSql/FlightSqlStatement.cs index 2ed04f7a37..5e80974727 100644 --- a/csharp/src/Drivers/FlightSql/FlightSqlStatement.cs +++ b/csharp/src/Drivers/FlightSql/FlightSqlStatement.cs @@ -16,7 +16,6 @@ */ using System; -using System.Text.RegularExpressions; using System.Threading.Tasks; using Apache.Arrow.Flight; using Grpc.Core; @@ -58,176 +57,5 @@ public async ValueTask GetInfo(string query, Metadata headers) return await _flightSqlConnection.FlightClient.GetInfo(commandDescripter, headers).ResponseAsync; } - - /// - /// Gets a value from the Arrow array at the specified index - /// using the Arrow field for metadata. - /// - /// - /// The array containing the value. - /// - /// - /// The Arrow field. - /// - /// - /// The index of the item. - /// - /// - /// The item at the index position. - /// - public override object GetValue(IArrowArray arrowArray, Field field, int index) - { - if (arrowArray is BooleanArray) - { - return Convert.ToBoolean(((BooleanArray)arrowArray).Values[index]); - } - else if (arrowArray is Date32Array) - { - Date32Array date32Array = (Date32Array)arrowArray; - - return date32Array.GetDateTime(index); - } - else if (arrowArray is Date64Array) - { - Date64Array date64Array = (Date64Array)arrowArray; - - return date64Array.GetDateTime(index); - } - else if (arrowArray is Decimal128Array) - { - try - { - // the value may be decimal.max - // then Arrow throws an exception - // no good way to check prior to - return ((Decimal128Array)arrowArray).GetValue(index); - } - catch (OverflowException oex) - { - return ParseDecimalValueFromOverflowException(oex); - } - } - else if (arrowArray is Decimal256Array) - { - try - { - return ((Decimal256Array)arrowArray).GetValue(index); - } - catch (OverflowException oex) - { - return ParseDecimalValueFromOverflowException(oex); - } - } - else if (arrowArray is DoubleArray) - { - return ((DoubleArray)arrowArray).GetValue(index); - } - else if (arrowArray is FloatArray) - { - return ((FloatArray)arrowArray).GetValue(index); - } -#if NET5_0_OR_GREATER - else if (arrowArray is PrimitiveArray) - { - // TODO: HalfFloatArray not present in current library - - return ((PrimitiveArray)arrowArray).GetValue(index); - } -#endif - else if (arrowArray is Int8Array) - { - return ((Int8Array)arrowArray).GetValue(index); - } - else if (arrowArray is Int16Array) - { - return ((Int16Array)arrowArray).GetValue(index); - } - else if (arrowArray is Int32Array) - { - return ((Int32Array)arrowArray).GetValue(index); - } - else if (arrowArray is Int64Array) - { - Int64Array array = (Int64Array)arrowArray; - return array.GetValue(index); - } - else if (arrowArray is StringArray) - { - return ((StringArray)arrowArray).GetString(index); - } - else if (arrowArray is Time32Array) - { - return ((Time32Array)arrowArray).GetValue(index); - } - else if (arrowArray is Time64Array) - { - return ((Time64Array)arrowArray).GetValue(index); - } - else if (arrowArray is TimestampArray) - { - TimestampArray timestampArray = (TimestampArray)arrowArray; - DateTimeOffset dateTimeOffset = timestampArray.GetTimestamp(index).Value; - return dateTimeOffset; - } - else if (arrowArray is UInt8Array) - { - return ((UInt8Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt16Array) - { - return ((UInt16Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt32Array) - { - return ((UInt32Array)arrowArray).GetValue(index); - } - else if (arrowArray is UInt64Array) - { - return ((UInt64Array)arrowArray).GetValue(index); - } - - // not covered: - // -- struct array - // -- binary array - // -- dictionary array - // -- fixed size binary - // -- list array - // -- union array - - return null; - } - - /// - /// For decimals, Arrow throws an OverflowException if a value - /// is < decimal.min or > decimal.max - /// So parse the numeric value and return it as a string, - /// if possible - /// - /// The OverflowException - /// - /// A string value of the decimal that threw the exception - /// or rethrows the OverflowException. - /// - /// - private string ParseDecimalValueFromOverflowException(OverflowException oex) - { - if (oex == null) - throw new ArgumentNullException(nameof(oex)); - - // any decimal value, positive or negative, with or without a decimal in place - Regex regex = new Regex(" -?\\d*\\.?\\d* "); - - var matches = regex.Matches(oex.Message); - - foreach (Match match in matches) - { - string value = match.Value; - - if (!string.IsNullOrEmpty(value)) - return value; - } - - throw oex; - } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs index 86fa25393c..fe5812c698 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/AdbcTests.cs @@ -21,7 +21,6 @@ using System.Linq; using System.Reflection; using Xunit; -using Assert = Microsoft.VisualStudio.TestTools.UnitTesting.Assert; namespace Apache.Arrow.Adbc.Tests { @@ -72,7 +71,7 @@ public void ValidateInfoCodes(AdbcInfoCode code, string adbcName, int value) /// The value of the ADBC value. private void ValidateEnumValue(int enumValue, string adbcName, int value) { - Assert.AreEqual(enumValue, value); + Assert.Equal(enumValue, value); // find the corresponding value in adbc.h and validate it string path = GetPathForAdbcH(); @@ -81,11 +80,11 @@ private void ValidateEnumValue(int enumValue, string adbcName, int value) string line = File.ReadAllLines(path).Where(x => x.StartsWith(pattern)).FirstOrDefault(); - Assert.IsFalse(string.IsNullOrEmpty(line)); + Assert.False(string.IsNullOrEmpty(line)); string definedValue = line.Replace(pattern, "").Trim(); - Assert.AreEqual(value, Convert.ToInt32(definedValue)); + Assert.Equal(value, Convert.ToInt32(definedValue)); } // C# is designed to match Java's AdbcDriver @@ -171,9 +170,9 @@ private void ValidateMethod(Type t, string methodName, string[] parameterNames = if (parameterNames != null) { - Assert.IsTrue(parameterNames.Length > 0); - Assert.IsTrue(parameterTypes != null); - Assert.AreEqual(parameterNames.Length, parameterTypes.Length); + Assert.True(parameterNames.Length > 0); + Assert.True(parameterTypes != null); + Assert.Equal(parameterNames.Length, parameterTypes.Length); ParameterInfo[] parameters = mi.GetParameters(); @@ -181,8 +180,8 @@ private void ValidateMethod(Type t, string methodName, string[] parameterNames = { ParameterInfo parameter = parameters[i]; - Assert.AreEqual(parameter.Name, parameterNames[i]); - Assert.AreEqual(parameter.ParameterType, parameterTypes[i]); + Assert.Equal(parameter.Name, parameterNames[i]); + Assert.Equal(parameter.ParameterType, parameterTypes[i]); } } } @@ -198,7 +197,7 @@ private void ValidateProperty(Type t, string propertyName, Type propertyType) { PropertyInfo pi = t.GetProperty(propertyName, propertyType); - Assert.IsNotNull(pi); + Assert.NotNull(pi); } private string GetPathForAdbcH() @@ -207,7 +206,7 @@ private string GetPathForAdbcH() string path = Path.Combine(new string[] { "..", "..", "..", "..", "..", "adbc.h" }); - Assert.IsTrue(File.Exists(path)); + Assert.True(File.Exists(path)); return path; } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj index e0bb70461b..eb53f350ac 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Apache.Arrow.Adbc.Tests.csproj @@ -10,15 +10,14 @@ - - - - + + + diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs new file mode 100644 index 0000000000..9548b85e4a --- /dev/null +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Client/ClientTests.cs @@ -0,0 +1,128 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Client; +using Apache.Arrow.Ipc; +using Apache.Arrow.Types; +using Moq; +using Xunit; + +namespace Apache.Arrow.Adbc.Tests.Client +{ + public class ClientTests + { + [Theory] + [InlineData(DecimalBehavior.OverflowDecimalAsString, "79228162514264337593543950335", 29, 0, typeof(decimal))] + [InlineData(DecimalBehavior.OverflowDecimalAsString, "792281625142643375935439503351", 30, 0, typeof(string))] + [InlineData(DecimalBehavior.UseSqlDecimal, "792281625142643375935439503351", 30, 0, typeof(SqlDecimal))] + public void TestDecimalValues(DecimalBehavior decimalBehavior, string value, int precision, int scale, Type expectedType) + { + AdbcDataReader rdr = GetMoqDataReader(decimalBehavior, value, precision, scale); + object rdrValue = rdr.GetValue(0); + + Assert.True(rdrValue.GetType().Equals(expectedType)); + } + + private AdbcDataReader GetMoqDataReader(DecimalBehavior decimalBehavior, string value, int precision, int scale) + { + SqlDecimal sqlDecimal = SqlDecimal.Parse(value); + + List> metadata = new List>(); + metadata.Add(new KeyValuePair("precision", precision.ToString())); + metadata.Add(new KeyValuePair("scale", scale.ToString())); + + List fields = new List(); + fields.Add(new Field("Decimal128t", new Decimal128Type(precision, scale), true, metadata)); + + Schema schema = new Schema(fields, metadata); + Decimal128Array.Builder builder = new Decimal128Array.Builder(new Decimal128Type(precision, scale)); + builder.Append(sqlDecimal); + Decimal128Array array = builder.Build(); + + List values = new List() { array }; + + List records = new List() + { + new RecordBatch(schema, values, values.Count) + }; + + MockArrayStream mockArrayStream = new MockArrayStream(schema, records); + QueryResult queryResult = new QueryResult(1, mockArrayStream); + + Mock mockStatement = new Mock(); + mockStatement.Setup(x => x.ExecuteQuery()).Returns(queryResult); ; + mockStatement.Setup(x => x.GetValue(It.IsAny(), It.IsAny(), It.IsAny())).Returns(sqlDecimal); + + Adbc.Client.AdbcConnection mockConnection = new Adbc.Client.AdbcConnection(); + mockConnection.DecimalBehavior = decimalBehavior; + + AdbcCommand cmd = new AdbcCommand(mockStatement.Object, mockConnection); + + AdbcDataReader reader = cmd.ExecuteReader(); + return reader; + } + } + + class MockArrayStream : IArrowArrayStream + { + private readonly List recordBatches; + private readonly Schema schema; + + // start at -1 to use the count the number of calls as the index + private int calls = -1; + + /// + /// Initializes the TestArrayStream. + /// + /// + /// The Arrow schema. + /// + /// + /// A list of record batches. + /// + public MockArrayStream(Schema schema, List recordBatches) + { + this.schema = schema; + this.recordBatches = recordBatches; + } + + public Schema Schema => this.schema; + + public void Dispose() { } + + /// + /// Moves through the list of record batches. + /// + /// + /// Optional cancellation token. + /// + public ValueTask ReadNextRecordBatchAsync(CancellationToken cancellationToken = default) + { + calls++; + + if (calls >= this.recordBatches.Count) + return new ValueTask(); + else + return new ValueTask(this.recordBatches[calls]); + } + } +} diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs index 94428cad78..1ecddb4b2f 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/ClientTests.cs @@ -22,7 +22,7 @@ using System.Data.Common; using System.Linq; using Apache.Arrow.Types; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Xunit; namespace Apache.Arrow.Adbc.Tests { @@ -53,17 +53,16 @@ public static void AssertTypeAndValue(ColumnNetTypeArrowTypeValue ctv, object va { dataTableType = row[SchemaTableColumn.DataType] as Type; arrowType = row[SchemaTableColumn.ProviderType] as IArrowType; - } } Type netType = reader[name]?.GetType(); - Assert.IsTrue(clientArrowType == ctv.ExpectedNetType, $"{name} is {clientArrowType.Name} and not {ctv.ExpectedNetType.Name} in the column schema"); + Assert.True(clientArrowType == ctv.ExpectedNetType, $"{name} is {clientArrowType.Name} and not {ctv.ExpectedNetType.Name} in the column schema"); - Assert.IsTrue(dataTableType == ctv.ExpectedNetType, $"{name} is {dataTableType.Name} and not {ctv.ExpectedNetType.Name} in the data table"); + Assert.True(dataTableType == ctv.ExpectedNetType, $"{name} is {dataTableType.Name} and not {ctv.ExpectedNetType.Name} in the data table"); - Assert.IsTrue(arrowType.GetType() == ctv.ExpectedArrowArrayType, $"{name} is {arrowType.Name} and not {ctv.ExpectedArrowArrayType.Name} in the provider type"); + Assert.True(arrowType.GetType() == ctv.ExpectedArrowArrayType, $"{name} is {arrowType.Name} and not {ctv.ExpectedArrowArrayType.Name} in the provider type"); if (netType != null) { @@ -75,7 +74,7 @@ public static void AssertTypeAndValue(ColumnNetTypeArrowTypeValue ctv, object va { object internalValue = value.GetType().GetMethod("GetValue").Invoke(value, new object[] { 0 }); - Assert.IsTrue(internalValue.GetType() == ctv.ExpectedNetType, $"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader"); + Assert.True(internalValue.GetType() == ctv.ExpectedNetType, $"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader"); } else { @@ -84,7 +83,7 @@ public static void AssertTypeAndValue(ColumnNetTypeArrowTypeValue ctv, object va } else { - Assert.IsTrue(netType == ctv.ExpectedNetType, $"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader"); + Assert.True(netType == ctv.ExpectedNetType, $"{name} is {netType.Name} and not {ctv.ExpectedNetType.Name} in the reader"); } } @@ -92,18 +91,18 @@ public static void AssertTypeAndValue(ColumnNetTypeArrowTypeValue ctv, object va { if (!value.GetType().BaseType.Name.Contains("PrimitiveArray")) { - Assert.AreEqual(ctv.ExpectedNetType, value.GetType(), $"Expected type does not match actual type for {ctv.Name}"); + Assert.True(ctv.ExpectedNetType == value.GetType(), $"Expected type does not match actual type for {ctv.Name}"); if (value.GetType() == typeof(byte[])) { byte[] actualBytes = (byte[])value; byte[] expectedBytes = (byte[])ctv.ExpectedValue; - Assert.IsTrue(actualBytes.SequenceEqual(expectedBytes), $"byte[] values do not match expected values for {ctv.Name}"); + Assert.True(actualBytes.SequenceEqual(expectedBytes), $"byte[] values do not match expected values for {ctv.Name}"); } else { - Assert.AreEqual(ctv.ExpectedValue, value, $"Expected value does not match actual value for {ctv.Name}"); + Assert.True(ctv.ExpectedValue.Equals(value), $"Expected value does not match actual value for {ctv.Name}"); } } else @@ -125,7 +124,7 @@ public static void AssertTypeAndValue(ColumnNetTypeArrowTypeValue ctv, object va if (i == j) { - Assert.AreEqual(expected, actual, $"Expected value does not match actual value for {ctv.Name} at {i}"); + Assert.True(expected.Equals(actual), $"Expected value does not match actual value for {ctv.Name} at {i}"); } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs index 14b7dd1099..70c0acac58 100644 --- a/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs +++ b/csharp/test/Apache.Arrow.Adbc.Tests/DriverTests.cs @@ -15,7 +15,7 @@ * limitations under the License. */ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Xunit; namespace Apache.Arrow.Adbc.Tests { @@ -45,14 +45,14 @@ public static void CanExecuteQuery(QueryResult queryResult, long expectedNumberO count += nextBatch.Length; } - Assert.AreEqual(expectedNumberOfResults, count, "The parsed records differ from the specified amount"); + Assert.True(expectedNumberOfResults == count, "The parsed records differ from the specified amount"); // if the values were set, make sure they are correct if (queryResult.RowCount != -1) { - Assert.AreEqual(queryResult.RowCount, expectedNumberOfResults, "The RowCount value does not match the expected results"); + Assert.True(queryResult.RowCount == expectedNumberOfResults, "The RowCount value does not match the expected results"); - Assert.AreEqual(queryResult.RowCount, count, "The RowCount value does not match the counted records"); + Assert.True(queryResult.RowCount == count, "The RowCount value does not match the counted records"); } } } diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/OrderAttribute.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/OrderAttribute.cs new file mode 100644 index 0000000000..7831fdf96f --- /dev/null +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/OrderAttribute.cs @@ -0,0 +1,32 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; + +namespace Apache.Arrow.Adbc.Tests.Xunit +{ + /// + /// Used to specify the order of Xunit tests. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class OrderAttribute : Attribute + { + public int Order { get; private set; } + + public OrderAttribute(int order) => Order = order; + } +} diff --git a/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/TestOrderer.cs b/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/TestOrderer.cs new file mode 100644 index 0000000000..5c30542b10 --- /dev/null +++ b/csharp/test/Apache.Arrow.Adbc.Tests/Xunit/TestOrderer.cs @@ -0,0 +1,62 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Linq; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Apache.Arrow.Adbc.Tests.Xunit +{ + /// + /// Orders the tests in a class to ensure they execute correctly. + /// + public class TestOrderer : ITestCaseOrderer + { + public IEnumerable OrderTestCases( + IEnumerable testCases) where TTestCase : ITestCase + { + string assemblyName = typeof(OrderAttribute).AssemblyQualifiedName!; + var sortedMethods = new SortedDictionary>(); + foreach (TTestCase testCase in testCases) + { + int priority = testCase.TestMethod.Method + .GetCustomAttributes(assemblyName) + .FirstOrDefault() + ?.GetNamedArgument(nameof(OrderAttribute.Order)) ?? 0; + + GetOrCreate(sortedMethods, priority).Add(testCase); + } + + foreach (TTestCase testCase in + sortedMethods.Keys.SelectMany( + priority => sortedMethods[priority].OrderBy( + testCase => testCase.TestMethod.Method.Name))) + { + yield return testCase; + } + } + + private static TValue GetOrCreate( + IDictionary dictionary, TKey key) + where TKey : struct + where TValue : new() => + dictionary.TryGetValue(key, out TValue result) + ? result + : (dictionary[key] = new TValue()); + } +} diff --git a/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj b/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj index 826e8af8f7..d4c2fb541b 100644 --- a/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj +++ b/csharp/test/Drivers/BigQuery/Apache.Arrow.Adbc.Tests.Drivers.BigQuery.csproj @@ -3,12 +3,13 @@ net472;net6.0 - - - - - - + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/csharp/test/Drivers/BigQuery/ClientTests.cs b/csharp/test/Drivers/BigQuery/ClientTests.cs index 8989518615..36d2f19f56 100644 --- a/csharp/test/Drivers/BigQuery/ClientTests.cs +++ b/csharp/test/Drivers/BigQuery/ClientTests.cs @@ -21,54 +21,64 @@ using System.Data.Common; using Apache.Arrow.Adbc.Client; using Apache.Arrow.Adbc.Drivers.BigQuery; -using NUnit.Framework; +using Apache.Arrow.Adbc.Tests.Xunit; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery { /// /// Class for testing the ADBC Client using the BigQuery ADBC driver. /// - [TestFixture] + /// /// + /// Tests are ordered to ensure data is created for the other + /// queries to run. + /// + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] public class ClientTests { + public ClientTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the client execute updates. /// - [Test, Order(1)] + [SkippableFact, Order(1)] public void CanClientExecuteUpdate() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - using (Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) - { - adbcConnection.Open(); + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) + { + adbcConnection.Open(); - string[] queries = BigQueryTestingUtils.GetQueries(testConfiguration); + string[] queries = BigQueryTestingUtils.GetQueries(testConfiguration); - List expectedResults = new List() { -1, 1, 1, 1 }; + List expectedResults = new List() { -1, 1, 1, 1 }; - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcCommand adbcCommand = adbcConnection.CreateCommand(); - adbcCommand.CommandText = query; + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcCommand adbcCommand = adbcConnection.CreateCommand(); + adbcCommand.CommandText = query; - int rows = adbcCommand.ExecuteNonQuery(); + int rows = adbcCommand.ExecuteNonQuery(); - Assert.AreEqual(expectedResults[i], rows, $"The expected affected rows do not match the actual affected rows at position {i}."); - } + Assert.Equal(expectedResults[i], rows); } } } - [Test, Order(2)] + /// + /// Validates if the client can get the schema. + /// + [SkippableFact, Order(2)] public void CanClientGetSchema() { BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - using (Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) { AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); @@ -79,7 +89,7 @@ public void CanClientGetSchema() DataTable table = reader.GetSchemaTable(); // there is one row per field - Assert.AreEqual(testConfiguration.Metadata.ExpectedColumnCount, table.Rows.Count); + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, table.Rows.Count); } } @@ -87,87 +97,81 @@ public void CanClientGetSchema() /// Validates if the client can connect to a live server and /// parse the results. /// - [Test, Order(3)] + [SkippableFact, Order(3)] public void CanClientExecuteQuery() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - long count = 0; + long count = 0; - using (Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) - { - AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using (Adbc.Client.AdbcConnection adbcConnection = GetAdbcConnection(testConfiguration)) + { + AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); - adbcConnection.Open(); + adbcConnection.Open(); - AdbcDataReader reader = adbcCommand.ExecuteReader(); + AdbcDataReader reader = adbcCommand.ExecuteReader(); - try + try + { + while (reader.Read()) { - while (reader.Read()) - { - count++; + count++; - for(int i=0;i /// Validates if the client is retrieving and converting values /// to the expected types. /// - [Test, Order(4)] + [SkippableFact, Order(4)] public void VerifyTypesAndValues() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - Client.AdbcConnection dbConnection = GetAdbcConnection(testConfiguration); + Adbc.Client.AdbcConnection dbConnection = GetAdbcConnection(testConfiguration); - dbConnection.Open(); - DbCommand dbCommand = dbConnection.CreateCommand(); - dbCommand.CommandText = testConfiguration.Query; + dbConnection.Open(); + DbCommand dbCommand = dbConnection.CreateCommand(); + dbCommand.CommandText = testConfiguration.Query; - DbDataReader reader = dbCommand.ExecuteReader(CommandBehavior.Default); + DbDataReader reader = dbCommand.ExecuteReader(CommandBehavior.Default); - if (reader.Read()) - { - var column_schema = reader.GetColumnSchema(); - DataTable dataTable = reader.GetSchemaTable(); + if (reader.Read()) + { + var column_schema = reader.GetColumnSchema(); + DataTable dataTable = reader.GetSchemaTable(); - List expectedValues = SampleData.GetSampleData(); + List expectedValues = SampleData.GetSampleData(); - for (int i = 0; i < reader.FieldCount; i++) - { - object value = reader.GetValue(i); - ColumnNetTypeArrowTypeValue ctv = expectedValues[i]; + for (int i = 0; i < reader.FieldCount; i++) + { + object value = reader.GetValue(i); + ColumnNetTypeArrowTypeValue ctv = expectedValues[i]; - Tests.ClientTests.AssertTypeAndValue(ctv, value, reader, column_schema, dataTable); - } + Tests.ClientTests.AssertTypeAndValue(ctv, value, reader, column_schema, dataTable); } } } - private Client.AdbcConnection GetAdbcConnection(BigQueryTestConfiguration testConfiguration) + private Adbc.Client.AdbcConnection GetAdbcConnection(BigQueryTestConfiguration testConfiguration) { - return new Client.AdbcConnection( + return new Adbc.Client.AdbcConnection( new BigQueryDriver(), BigQueryTestingUtils.GetBigQueryParameters(testConfiguration), new Dictionary() diff --git a/csharp/test/Drivers/BigQuery/DriverTests.cs b/csharp/test/Drivers/BigQuery/DriverTests.cs index 0cfde12cab..9d470475b6 100644 --- a/csharp/test/Drivers/BigQuery/DriverTests.cs +++ b/csharp/test/Drivers/BigQuery/DriverTests.cs @@ -19,9 +19,9 @@ using System.Collections.Generic; using System.Linq; using Apache.Arrow.Adbc.Tests.Metadata; +using Apache.Arrow.Adbc.Tests.Xunit; using Apache.Arrow.Ipc; -using NUnit.Framework; -using NUnit.Framework.Internal; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery { @@ -32,196 +32,183 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.BigQuery /// Tests are ordered to ensure data is created for the other /// queries to run. /// - [TestFixture] + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] public class DriverTests { + public DriverTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the driver can connect to a live server and /// parse the results. /// - [Test, Order(1)] + [SkippableFact, Order(1)] public void CanExecuteUpdate() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - string[] queries = BigQueryTestingUtils.GetQueries(testConfiguration); + string[] queries = BigQueryTestingUtils.GetQueries(testConfiguration); - List expectedResults = new List() { -1, 1, 1, 1 }; + List expectedResults = new List() { -1, 1, 1, 1 }; - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = query; + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = query; - UpdateResult updateResult = statement.ExecuteUpdate(); + UpdateResult updateResult = statement.ExecuteUpdate(); - Assert.AreEqual(expectedResults[i], updateResult.AffectedRows, $"The expected affected rows do not match the actual affected rows at position {i}."); - } + Assert.Equal(expectedResults[i], updateResult.AffectedRows); } } /// /// Validates if the driver can call GetInfo. /// - [Test, Order(2)] + [SkippableFact, Order(2)] public void CanGetInfo() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - IArrowArrayStream stream = adbcConnection.GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); + IArrowArrayStream stream = adbcConnection.GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); - RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); - List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; + List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; - for (int i = 0; i < infoNameArray.Length; i++) - { - AdbcInfoCode value = (AdbcInfoCode)infoNameArray.GetValue(i); - DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + for (int i = 0; i < infoNameArray.Length; i++) + { + AdbcInfoCode value = (AdbcInfoCode)infoNameArray.GetValue(i); + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); - Assert.IsTrue(expectedValues.Contains(value.ToString())); + Assert.Contains(value.ToString(), expectedValues); - StringArray stringArray = (StringArray)valueArray.Fields[0]; - Console.WriteLine($"{value}={stringArray.GetString(i)}"); - } + StringArray stringArray = (StringArray)valueArray.Fields[0]; + Console.WriteLine($"{value}={stringArray.GetString(i)}"); } } /// /// Validates if the driver can call GetObjects. /// - [Test, Order(3)] + [SkippableFact, Order(3)] public void CanGetObjects() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - // need to add the database - string catalogName = testConfiguration.Metadata.Catalog; - string schemaName = testConfiguration.Metadata.Schema; - string tableName = testConfiguration.Metadata.Table; - string columnName = null; + // need to add the database + string catalogName = testConfiguration.Metadata.Catalog; + string schemaName = testConfiguration.Metadata.Schema; + string tableName = testConfiguration.Metadata.Table; + string columnName = null; - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - IArrowArrayStream stream = adbcConnection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.All, - catalogPattern: catalogName, - dbSchemaPattern: schemaName, - tableNamePattern: tableName, - tableTypes: new List { "BASE TABLE", "VIEW" }, - columnNamePattern: columnName); + IArrowArrayStream stream = adbcConnection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.All, + catalogPattern: catalogName, + dbSchemaPattern: schemaName, + tableNamePattern: tableName, + tableTypes: new List { "BASE TABLE", "VIEW" }, + columnNamePattern: columnName); - RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, catalogName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, catalogName, schemaName); - List columns = catalogs - .Select(s => s.DbSchemas) - .FirstOrDefault() - .Select(t => t.Tables) - .FirstOrDefault() - .Select(c => c.Columns) - .FirstOrDefault(); + List columns = catalogs + .Select(s => s.DbSchemas) + .FirstOrDefault() + .Select(t => t.Tables) + .FirstOrDefault() + .Select(c => c.Columns) + .FirstOrDefault(); - Assert.AreEqual(testConfiguration.Metadata.ExpectedColumnCount, columns.Count); - } + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, columns.Count); } /// /// Validates if the driver can call GetTableSchema. /// - [Test, Order(4)] + [SkippableFact, Order(4)] public void CanGetTableSchema() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - string catalogName = testConfiguration.Metadata.Catalog; - string schemaName = testConfiguration.Metadata.Schema; - string tableName = testConfiguration.Metadata.Table; + string catalogName = testConfiguration.Metadata.Catalog; + string schemaName = testConfiguration.Metadata.Schema; + string tableName = testConfiguration.Metadata.Table; - Schema schema = adbcConnection.GetTableSchema(catalogName, schemaName, tableName); + Schema schema = adbcConnection.GetTableSchema(catalogName, schemaName, tableName); - int numberOfFields = schema.FieldsList.Count; + int numberOfFields = schema.FieldsList.Count; - Assert.AreEqual(testConfiguration.Metadata.ExpectedColumnCount, numberOfFields); - } + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, numberOfFields); } /// /// Validates if the driver can call GetTableTypes. /// - [Test, Order(5)] + [SkippableFact, Order(5)] public void CanGetTableTypes() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); + IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); - RecordBatch recordBatch = arrowArrayStream.ReadNextRecordBatchAsync().Result; + RecordBatch recordBatch = arrowArrayStream.ReadNextRecordBatchAsync().Result; - StringArray stringArray = (StringArray)recordBatch.Column("table_type"); + StringArray stringArray = (StringArray)recordBatch.Column("table_type"); - List known_types = new List - { - "BASE TABLE", "VIEW" - }; + List known_types = new List + { + "BASE TABLE", "VIEW" + }; - int results = 0; + int results = 0; - for (int i = 0; i < stringArray.Length; i++) - { - string value = stringArray.GetString(i); + for (int i = 0; i < stringArray.Length; i++) + { + string value = stringArray.GetString(i); - if (known_types.Contains(value)) - { - results++; - } + if (known_types.Contains(value)) + { + results++; } - - Assert.AreEqual(known_types.Count, results); } + + Assert.Equal(known_types.Count, results); } /// /// Validates if the driver can connect to a live server and /// parse the results. /// - [Test, Order(6)] + [SkippableFact, Order(6)] public void CanExecuteQuery() { - if (Utils.CanExecuteTestConfig(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE)) - { - BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); + BigQueryTestConfiguration testConfiguration = Utils.LoadTestConfiguration(BigQueryTestingUtils.BIGQUERY_TEST_CONFIG_VARIABLE); - AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); + AdbcConnection adbcConnection = BigQueryTestingUtils.GetBigQueryAdbcConnection(testConfiguration); - AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = testConfiguration.Query; + AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = testConfiguration.Query; - QueryResult queryResult = statement.ExecuteQuery(); + QueryResult queryResult = statement.ExecuteQuery(); - Tests.DriverTests.CanExecuteQuery(queryResult, testConfiguration.ExpectedResultsCount); - } + Tests.DriverTests.CanExecuteQuery(queryResult, testConfiguration.ExpectedResultsCount); } } } diff --git a/csharp/test/Drivers/BigQuery/SampleData.cs b/csharp/test/Drivers/BigQuery/SampleData.cs index 172331079d..d97e51dcc1 100644 --- a/csharp/test/Drivers/BigQuery/SampleData.cs +++ b/csharp/test/Drivers/BigQuery/SampleData.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using System.Data.SqlTypes; using System.Text; using Apache.Arrow.Types; @@ -40,8 +41,8 @@ public static List GetSampleData() { new ColumnNetTypeArrowTypeValue("id", typeof(long), typeof(Int64Type), 1L), new ColumnNetTypeArrowTypeValue("number", typeof(double), typeof(DoubleType), 1.23d), - new ColumnNetTypeArrowTypeValue("decimal", typeof(decimal), typeof(Decimal128Type), decimal.Parse("4.56")), - new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), "789000000000000000000000000000000000000"), + new ColumnNetTypeArrowTypeValue("decimal", typeof(SqlDecimal), typeof(Decimal128Type), SqlDecimal.Parse("4.56")), + new ColumnNetTypeArrowTypeValue("big_decimal", typeof(string), typeof(StringType), "7.89000000000000000000000000000000000000"), new ColumnNetTypeArrowTypeValue("is_active", typeof(bool), typeof(BooleanType), true), new ColumnNetTypeArrowTypeValue("name", typeof(string), typeof(StringType), "John Doe"), new ColumnNetTypeArrowTypeValue("data", typeof(byte[]), typeof(BinaryType), UTF8Encoding.UTF8.GetBytes("abc123")), diff --git a/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj b/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj index ab57649335..7647af2d76 100644 --- a/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj +++ b/csharp/test/Drivers/FlightSql/Apache.Arrow.Adbc.Tests.Drivers.FlightSql.csproj @@ -5,11 +5,13 @@ False - - - - - + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/csharp/test/Drivers/FlightSql/ClientTests.cs b/csharp/test/Drivers/FlightSql/ClientTests.cs index 765e6371c9..04f3c6f512 100644 --- a/csharp/test/Drivers/FlightSql/ClientTests.cs +++ b/csharp/test/Drivers/FlightSql/ClientTests.cs @@ -15,28 +15,34 @@ * limitations under the License. */ +using System; using System.Collections.Generic; using Apache.Arrow.Adbc.Client; using Apache.Arrow.Adbc.Drivers.FlightSql; -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.FlightSql { - [TestClass] + /// + /// Class for testing the ADBC Client using the FlightSql ADBC driver. + /// public class ClientTests { + public ClientTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the client can connect to a live server and /// parse the results. /// - [TestMethod] + [SkippableFact] public void CanFlightSqlConnectUsingClient() { - if (Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE)) - { - FlightSqlTestConfiguration flightSqlTestConfiguration = Utils.LoadTestConfiguration(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE); + FlightSqlTestConfiguration flightSqlTestConfiguration = Utils.LoadTestConfiguration(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE); - Dictionary parameters = new Dictionary + Dictionary parameters = new Dictionary { { FlightSqlParameters.ServerAddress, flightSqlTestConfiguration.ServerAddress }, { FlightSqlParameters.RoutingTag, flightSqlTestConfiguration.RoutingTag }, @@ -44,39 +50,48 @@ public void CanFlightSqlConnectUsingClient() { FlightSqlParameters.Authorization, flightSqlTestConfiguration.Authorization} }; - Dictionary options = new Dictionary() + Dictionary options = new Dictionary() { { FlightSqlParameters.ServerAddress, flightSqlTestConfiguration.ServerAddress }, }; - long count = 0; + long count = 0; - using (Client.AdbcConnection adbcConnection = new Client.AdbcConnection( - new FlightSqlDriver(), - parameters, - options) - ) - { - string query = flightSqlTestConfiguration.Query; + using (Adbc.Client.AdbcConnection adbcConnection = new Adbc.Client.AdbcConnection( + new FlightSqlDriver(), + parameters, + options) + ) + { + string query = flightSqlTestConfiguration.Query; - AdbcCommand adbcCommand = new AdbcCommand(query, adbcConnection); + AdbcCommand adbcCommand = new AdbcCommand(query, adbcConnection); - adbcConnection.Open(); + adbcConnection.Open(); - AdbcDataReader reader = adbcCommand.ExecuteReader(); + AdbcDataReader reader = adbcCommand.ExecuteReader(); - try + try + { + while (reader.Read()) { - while (reader.Read()) + count++; + + for (int i = 0; i < reader.FieldCount; i++) { - count++; + object value = reader.GetValue(i); + + if (value == null) + value = "(null)"; + + Console.WriteLine($"{reader.GetName(i)}: {value}"); } } - finally { reader.Close(); } } - - Assert.AreEqual(flightSqlTestConfiguration.ExpectedResultsCount, count); + finally { reader.Close(); } } + + Assert.Equal(flightSqlTestConfiguration.ExpectedResultsCount, count); } } } diff --git a/csharp/test/Drivers/FlightSql/DriverTests.cs b/csharp/test/Drivers/FlightSql/DriverTests.cs index a591f5a9b7..db3d3a585d 100644 --- a/csharp/test/Drivers/FlightSql/DriverTests.cs +++ b/csharp/test/Drivers/FlightSql/DriverTests.cs @@ -15,32 +15,31 @@ * limitations under the License. */ -using System; using System.Collections.Generic; using Apache.Arrow.Adbc.Drivers.FlightSql; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Moq; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.FlightSql { /// - /// Abstract class for the ADBC connection tests. + /// Class for testing the FlightSql ADBC driver connection tests. /// - [TestClass] public class DriverTests { + public DriverTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the driver can connect to a live server and /// parse the results. /// - [TestMethod] public void CanDriverExecuteQuery() { - if (Utils.CanExecuteTestConfig(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE)) - { - FlightSqlTestConfiguration flightSqlTestConfiguration = Utils.LoadTestConfiguration(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE); + FlightSqlTestConfiguration flightSqlTestConfiguration = Utils.LoadTestConfiguration(FlightSqlTestingUtils.FLIGHTSQL_TEST_CONFIG_VARIABLE); - Dictionary parameters = new Dictionary + Dictionary parameters = new Dictionary { { FlightSqlParameters.ServerAddress, flightSqlTestConfiguration.ServerAddress }, { FlightSqlParameters.RoutingTag, flightSqlTestConfiguration.RoutingTag }, @@ -48,21 +47,20 @@ public void CanDriverExecuteQuery() { FlightSqlParameters.Authorization, flightSqlTestConfiguration.Authorization} }; - Dictionary options = new Dictionary() + Dictionary options = new Dictionary() { { FlightSqlParameters.ServerAddress, flightSqlTestConfiguration.ServerAddress }, }; - FlightSqlDriver flightSqlDriver = new FlightSqlDriver(); - FlightSqlDatabase flightSqlDatabase = flightSqlDriver.Open(parameters) as FlightSqlDatabase; - FlightSqlConnection connection = flightSqlDatabase.Connect(options) as FlightSqlConnection; - FlightSqlStatement statement = connection.CreateStatement() as FlightSqlStatement; + FlightSqlDriver flightSqlDriver = new FlightSqlDriver(); + FlightSqlDatabase flightSqlDatabase = flightSqlDriver.Open(parameters) as FlightSqlDatabase; + FlightSqlConnection connection = flightSqlDatabase.Connect(options) as FlightSqlConnection; + FlightSqlStatement statement = connection.CreateStatement() as FlightSqlStatement; - statement.SqlQuery = flightSqlTestConfiguration.Query; - QueryResult queryResult = statement.ExecuteQuery(); + statement.SqlQuery = flightSqlTestConfiguration.Query; + QueryResult queryResult = statement.ExecuteQuery(); - Tests.DriverTests.CanExecuteQuery(queryResult, flightSqlTestConfiguration.ExpectedResultsCount); - } + Tests.DriverTests.CanExecuteQuery(queryResult, flightSqlTestConfiguration.ExpectedResultsCount); } } } diff --git a/csharp/test/Drivers/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj b/csharp/test/Drivers/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj index d7ade6c84e..6fa40d80b0 100644 --- a/csharp/test/Drivers/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj +++ b/csharp/test/Drivers/Snowflake/Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake.csproj @@ -3,12 +3,13 @@ net472;net6.0 - - - - - - + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/csharp/test/Drivers/Snowflake/ClientTests.cs b/csharp/test/Drivers/Snowflake/ClientTests.cs index 8dcb9deea0..99c01b8e27 100644 --- a/csharp/test/Drivers/Snowflake/ClientTests.cs +++ b/csharp/test/Drivers/Snowflake/ClientTests.cs @@ -20,9 +20,11 @@ using System.Collections.ObjectModel; using System.Data; using System.Data.Common; +using System.Data.SqlTypes; using System.IO; using Apache.Arrow.Adbc.Client; -using NUnit.Framework; +using Apache.Arrow.Adbc.Tests.Xunit; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake { @@ -33,37 +35,39 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake /// Tests are ordered to ensure data is created /// for the other queries to run. /// - [TestFixture] + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] public class ClientTests { + public ClientTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the client execute updates. /// - [Test, Order(1)] + [SkippableFact, Order(1)] public void CanClientExecuteUpdate() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) - { - adbcConnection.Open(); + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + { + adbcConnection.Open(); - string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); + string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); - List expectedResults = new List() { -1, 1, 1 }; + List expectedResults = new List() { -1, 1, 1 }; - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcCommand adbcCommand = adbcConnection.CreateCommand(); - adbcCommand.CommandText = query; + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcCommand adbcCommand = adbcConnection.CreateCommand(); + adbcCommand.CommandText = query; - int rows = adbcCommand.ExecuteNonQuery(); + int rows = adbcCommand.ExecuteNonQuery(); - Assert.AreEqual(expectedResults[i], rows, $"The expected affected rows do not match the actual affected rows at position {i}."); - } + Assert.Equal(expectedResults[i], rows); } } } @@ -71,62 +75,59 @@ public void CanClientExecuteUpdate() /// /// Validates if the client execute updates using the reader. /// - [Test, Order(2)] + [SkippableFact, Order(2)] public void CanClientExecuteUpdateUsingExecuteReader() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) - { - adbcConnection.Open(); + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + { + adbcConnection.Open(); - string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); + string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); - List expectedResults = new List() { string.Format("Table {0} successfully created.",testConfiguration.Metadata.Table), 1L, 1L }; + List expectedResults = new List() { $"Table {testConfiguration.Metadata.Table} successfully created.", new SqlDecimal(1L), new SqlDecimal(1L) }; - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcCommand adbcCommand = adbcConnection.CreateCommand(); - adbcCommand.CommandText = query; + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcCommand adbcCommand = adbcConnection.CreateCommand(); + adbcCommand.CommandText = query; - AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.Default); + AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.Default); - if (reader.Read()) - { - Assert.AreEqual(expectedResults[i], reader.GetValue(0), $"The expected affected rows do not match the actual affected rows at position {i}."); - } - else - { - Assert.Fail("Could not read the records"); - } + if (reader.Read()) + { + Assert.True(expectedResults[i].Equals(reader.GetValue(0)), $"The expected affected rows do not match the actual affected rows at position {i}."); + } + else + { + Assert.Fail("Could not read the records"); } } } } - [Test, Order(2)] + /// + /// Validates if the client can get the schema. + /// + [SkippableFact, Order(3)] public void CanClientGetSchema() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) - { - AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + { + AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); - adbcConnection.Open(); + adbcConnection.Open(); - AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.SchemaOnly); + AdbcDataReader reader = adbcCommand.ExecuteReader(CommandBehavior.SchemaOnly); - DataTable table = reader.GetSchemaTable(); + DataTable table = reader.GetSchemaTable(); - // there is one row per field - Assert.AreEqual(testConfiguration.Metadata.ExpectedColumnCount, table.Rows.Count); - } + // there is one row per field + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, table.Rows.Count); } } @@ -134,123 +135,129 @@ public void CanClientGetSchema() /// Validates if the client can connect to a live server /// and parse the results. /// - [Test, Order(3)] + [SkippableFact, Order(4)] public void CanClientExecuteQuery() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - long count = 0; + long count = 0; - using (Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) - { - AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + { + AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); - adbcConnection.Open(); + adbcConnection.Open(); - AdbcDataReader reader = adbcCommand.ExecuteReader(); + AdbcDataReader reader = adbcCommand.ExecuteReader(); - try + try + { + while (reader.Read()) { - while (reader.Read()) + count++; + + for (int i = 0; i < reader.FieldCount; i++) { - count++; + object value = reader.GetValue(i); - for (int i = 0; i < reader.FieldCount; i++) - { - Console.WriteLine($"{reader.GetName(i)}: {reader.GetValue(i)}"); - } + if (value == null) + value = "(null)"; + + Console.WriteLine($"{reader.GetName(i)}: {value}"); } } - finally { reader.Close(); } } - - Assert.AreEqual(testConfiguration.ExpectedResultsCount, count); + finally { reader.Close(); } } + + Assert.Equal(testConfiguration.ExpectedResultsCount, count); } /// /// Validates if the client can connect to a live server /// using a connection string / private key and parse the results. /// - [Test, Order(4)] + [SkippableFact, Order(5)] public void CanClientExecuteQueryUsingPrivateKey() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - long count = 0; + long count = 0; - using (Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) - { - AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) + { + AdbcCommand adbcCommand = new AdbcCommand(testConfiguration.Query, adbcConnection); - adbcConnection.Open(); + adbcConnection.Open(); - AdbcDataReader reader = adbcCommand.ExecuteReader(); + AdbcDataReader reader = adbcCommand.ExecuteReader(); - try + try + { + while (reader.Read()) { - while (reader.Read()) + count++; + + for (int i = 0; i < reader.FieldCount; i++) { - count++; + object value = reader.GetValue(i); + + if (value == null) + value = "(null)"; + + Console.WriteLine($"{reader.GetName(i)}: {value}"); } } - finally { reader.Close(); } } - - Assert.AreEqual(testConfiguration.ExpectedResultsCount, count); + finally { reader.Close(); } } + + Assert.Equal(testConfiguration.ExpectedResultsCount, count); } /// /// Validates if the client is retrieving and converting values /// to the expected types. /// - [Test, Order(5)] + [SkippableFact, Order(6)] public void VerifyTypesAndValues() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - Client.AdbcConnection dbConnection = GetSnowflakeAdbcConnection(testConfiguration); - dbConnection.Open(); + Adbc.Client.AdbcConnection dbConnection = GetSnowflakeAdbcConnection(testConfiguration); + dbConnection.Open(); - DbCommand dbCommand = dbConnection.CreateCommand(); - dbCommand.CommandText = testConfiguration.Query; + DbCommand dbCommand = dbConnection.CreateCommand(); + dbCommand.CommandText = testConfiguration.Query; - DbDataReader reader = dbCommand.ExecuteReader(CommandBehavior.Default); + DbDataReader reader = dbCommand.ExecuteReader(CommandBehavior.Default); - if (reader.Read()) - { - ReadOnlyCollection column_schema = reader.GetColumnSchema(); + if (reader.Read()) + { + ReadOnlyCollection column_schema = reader.GetColumnSchema(); - DataTable dataTable = reader.GetSchemaTable(); + DataTable dataTable = reader.GetSchemaTable(); - List expectedValues = SampleData.GetSampleData(); + List expectedValues = SampleData.GetSampleData(); - for (int i = 0; i < reader.FieldCount; i++) - { - object value = reader.GetValue(i); - ColumnNetTypeArrowTypeValue ctv = expectedValues[i]; + for (int i = 0; i < reader.FieldCount; i++) + { + object value = reader.GetValue(i); + ColumnNetTypeArrowTypeValue ctv = expectedValues[i]; - string readerColumnName = reader.GetName(i); - string dataTableColumnName = dataTable.Rows[i][SchemaTableColumn.ColumnName].ToString(); + string readerColumnName = reader.GetName(i); + string dataTableColumnName = dataTable.Rows[i][SchemaTableColumn.ColumnName].ToString(); - Assert.IsTrue(readerColumnName.Equals(ctv.Name, StringComparison.OrdinalIgnoreCase), $"`{readerColumnName}` != `{ctv.Name}` at position {i}. Verify the test query and sample data return in the same order in the reader."); + Assert.True(readerColumnName.Equals(ctv.Name, StringComparison.OrdinalIgnoreCase), $"`{readerColumnName}` != `{ctv.Name}` at position {i}. Verify the test query and sample data return in the same order in the reader."); - Assert.IsTrue(dataTableColumnName.Equals(ctv.Name, StringComparison.OrdinalIgnoreCase), $"`{dataTableColumnName}` != `{ctv.Name}` at position {i}. Verify the test query and sample data return in the same order in the data table."); + Assert.True(dataTableColumnName.Equals(ctv.Name, StringComparison.OrdinalIgnoreCase), $"`{dataTableColumnName}` != `{ctv.Name}` at position {i}. Verify the test query and sample data return in the same order in the data table."); - Tests.ClientTests.AssertTypeAndValue(ctv, value, reader, column_schema, dataTable); - } + Tests.ClientTests.AssertTypeAndValue(ctv, value, reader, column_schema, dataTable); } } } - private Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionString(SnowflakeTestConfiguration testConfiguration) + private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionString(SnowflakeTestConfiguration testConfiguration) { // see https://arrow.apache.org/adbc/0.5.1/driver/snowflake.html @@ -285,19 +292,19 @@ private Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionString(Sn AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration); - return new Client.AdbcConnection(builder.ConnectionString) + return new Adbc.Client.AdbcConnection(builder.ConnectionString) { AdbcDriver = snowflakeDriver }; } - private Client.AdbcConnection GetSnowflakeAdbcConnection(SnowflakeTestConfiguration testConfiguration) + private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnection(SnowflakeTestConfiguration testConfiguration) { Dictionary parameters = new Dictionary(); AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - Client.AdbcConnection adbcConnection = new Client.AdbcConnection( + Adbc.Client.AdbcConnection adbcConnection = new Adbc.Client.AdbcConnection( snowflakeDriver, parameters: parameters, options: new Dictionary() diff --git a/csharp/test/Drivers/Snowflake/DriverTests.cs b/csharp/test/Drivers/Snowflake/DriverTests.cs index 41527e99e2..0312fff958 100644 --- a/csharp/test/Drivers/Snowflake/DriverTests.cs +++ b/csharp/test/Drivers/Snowflake/DriverTests.cs @@ -19,9 +19,9 @@ using System.Collections.Generic; using System.Linq; using Apache.Arrow.Adbc.Tests.Metadata; +using Apache.Arrow.Adbc.Tests.Xunit; using Apache.Arrow.Ipc; -using NUnit.Framework; -using NUnit.Framework.Internal; +using Xunit; namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake { @@ -32,9 +32,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Interop.Snowflake /// Tests are ordered to ensure data is created for the other /// queries to run. /// - [TestFixture] + [TestCaseOrderer("Apache.Arrow.Adbc.Tests.Xunit.TestOrderer", "Apache.Arrow.Adbc.Tests")] public class DriverTests { + public DriverTests() + { + Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)); + } + /// /// Validates if the driver can connect to a live server and /// parse the results. @@ -43,228 +48,210 @@ public class DriverTests /// Tests are ordered to ensure data is created /// for the other queries to run. /// - [Test, Order(1)] + [SkippableFact, Order(1)] public void CanExecuteUpdate() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - Dictionary parameters = new Dictionary(); - Dictionary options = new Dictionary(); + Dictionary parameters = new Dictionary(); + Dictionary options = new Dictionary(); - AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); + AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - AdbcDatabase adbcDatabase = snowflakeDriver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(options); + AdbcDatabase adbcDatabase = snowflakeDriver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(options); - string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); + string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); - List expectedResults = new List() { -1, 1, 1 }; + List expectedResults = new List() { -1, 1, 1 }; - for (int i = 0; i < queries.Length; i++) - { - string query = queries[i]; - AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = query; + for (int i = 0; i < queries.Length; i++) + { + string query = queries[i]; + AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = query; - UpdateResult updateResult = statement.ExecuteUpdate(); + UpdateResult updateResult = statement.ExecuteUpdate(); - Assert.AreEqual(expectedResults[i], updateResult.AffectedRows, $"The expected affected rows do not match the actual affected rows at position {i}."); - } + Assert.Equal(expectedResults[i], updateResult.AffectedRows); } } /// /// Validates if the driver can call GetInfo. /// - [Test, Order(2)] + [SkippableFact, Order(2)] public void CanGetInfo() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - Dictionary parameters = new Dictionary(); + Dictionary parameters = new Dictionary(); - SnowflakeTestConfiguration metadataTestConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(metadataTestConfiguration, out parameters); + AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - AdbcDatabase adbcDatabase = driver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); + AdbcDatabase adbcDatabase = driver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); - IArrowArrayStream stream = adbcConnection.GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); + IArrowArrayStream stream = adbcConnection.GetInfo(new List() { AdbcInfoCode.DriverName, AdbcInfoCode.DriverVersion, AdbcInfoCode.VendorName }); - RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); + RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + UInt32Array infoNameArray = (UInt32Array)recordBatch.Column("info_name"); - List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; + List expectedValues = new List() { "DriverName", "DriverVersion", "VendorName" }; - for (int i = 0; i < infoNameArray.Length; i++) - { - AdbcInfoCode value = (AdbcInfoCode)infoNameArray.GetValue(i); - DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); + for (int i = 0; i < infoNameArray.Length; i++) + { + AdbcInfoCode value = (AdbcInfoCode)infoNameArray.GetValue(i); + DenseUnionArray valueArray = (DenseUnionArray)recordBatch.Column("info_value"); - Assert.IsTrue(expectedValues.Contains(value.ToString())); + Assert.Contains(value.ToString(), expectedValues); - StringArray stringArray = (StringArray)valueArray.Fields[0]; - Console.WriteLine($"{value}={stringArray.GetString(i)}"); - } + StringArray stringArray = (StringArray)valueArray.Fields[0]; + Console.WriteLine($"{value}={stringArray.GetString(i)}"); } } /// /// Validates if the driver can call GetObjects. /// - [Test, Order(3)] + [SkippableFact, Order(3)] public void CanGetObjects() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - Dictionary parameters = new Dictionary(); + Dictionary parameters = new Dictionary(); - SnowflakeTestConfiguration metadataTestConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(metadataTestConfiguration, out parameters); + AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - // need to add the database - string databaseName = metadataTestConfiguration.Metadata.Catalog; - string schemaName = metadataTestConfiguration.Metadata.Schema; - string tableName = metadataTestConfiguration.Metadata.Table; - string columnName = null; + // need to add the database + string databaseName = testConfiguration.Metadata.Catalog; + string schemaName = testConfiguration.Metadata.Schema; + string tableName = testConfiguration.Metadata.Table; + string columnName = null; - parameters[SnowflakeParameters.DATABASE] = databaseName; - parameters[SnowflakeParameters.SCHEMA] = schemaName; + parameters[SnowflakeParameters.DATABASE] = databaseName; + parameters[SnowflakeParameters.SCHEMA] = schemaName; - AdbcDatabase adbcDatabase = driver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); + AdbcDatabase adbcDatabase = driver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); - IArrowArrayStream stream = adbcConnection.GetObjects( - depth: AdbcConnection.GetObjectsDepth.All, - catalogPattern: databaseName, - dbSchemaPattern: schemaName, - tableNamePattern: tableName, - tableTypes: new List { "BASE TABLE", "VIEW" }, - columnNamePattern: columnName); + IArrowArrayStream stream = adbcConnection.GetObjects( + depth: AdbcConnection.GetObjectsDepth.All, + catalogPattern: databaseName, + dbSchemaPattern: schemaName, + tableNamePattern: tableName, + tableTypes: new List { "BASE TABLE", "VIEW" }, + columnNamePattern: columnName); - RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; + RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result; - List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); + List catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName); - List columns = catalogs - .Select(s => s.DbSchemas) - .FirstOrDefault() - .Select(t => t.Tables) - .FirstOrDefault() - .Select(c => c.Columns) - .FirstOrDefault(); + List columns = catalogs + .Select(s => s.DbSchemas) + .FirstOrDefault() + .Select(t => t.Tables) + .FirstOrDefault() + .Select(c => c.Columns) + .FirstOrDefault(); - Assert.IsTrue(columns != null, "Columns cannot be null"); + Assert.True(columns != null, "Columns cannot be null"); - Assert.AreEqual(metadataTestConfiguration.Metadata.ExpectedColumnCount, columns.Count); - } + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, columns.Count); } /// /// Validates if the driver can call GetTableSchema. /// - [Test, Order(4)] + [SkippableFact, Order(4)] public void CanGetTableSchema() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - Dictionary parameters = new Dictionary(); + Dictionary parameters = new Dictionary(); - SnowflakeTestConfiguration metadataTestConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(metadataTestConfiguration, out parameters); + AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - AdbcDatabase adbcDatabase = driver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); + AdbcDatabase adbcDatabase = driver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); - string databaseName = metadataTestConfiguration.Metadata.Catalog; - string schemaName = metadataTestConfiguration.Metadata.Schema; - string tableName = metadataTestConfiguration.Metadata.Table; + string databaseName = testConfiguration.Metadata.Catalog; + string schemaName = testConfiguration.Metadata.Schema; + string tableName = testConfiguration.Metadata.Table; - Schema schema = adbcConnection.GetTableSchema(databaseName, schemaName, tableName); + Schema schema = adbcConnection.GetTableSchema(databaseName, schemaName, tableName); - int numberOfFields = schema.FieldsList.Count; + int numberOfFields = schema.FieldsList.Count; - Assert.AreEqual(metadataTestConfiguration.Metadata.ExpectedColumnCount, numberOfFields); - } + Assert.Equal(testConfiguration.Metadata.ExpectedColumnCount, numberOfFields); } /// /// Validates if the driver can call GetTableTypes. /// - [Test, Order(5)] + [SkippableFact, Order(5)] public void CanGetTableTypes() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - Dictionary parameters = new Dictionary(); + Dictionary parameters = new Dictionary(); - SnowflakeTestConfiguration metadataTestConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(metadataTestConfiguration, out parameters); + AdbcDriver driver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - AdbcDatabase adbcDatabase = driver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); + AdbcDatabase adbcDatabase = driver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(new Dictionary()); - IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); + IArrowArrayStream arrowArrayStream = adbcConnection.GetTableTypes(); - RecordBatch recordBatch = arrowArrayStream.ReadNextRecordBatchAsync().Result; + RecordBatch recordBatch = arrowArrayStream.ReadNextRecordBatchAsync().Result; - StringArray stringArray = (StringArray)recordBatch.Column("table_type"); + StringArray stringArray = (StringArray)recordBatch.Column("table_type"); - List known_types = new List - { - "BASE TABLE", "TEMPORARY TABLE", "VIEW" - }; + List known_types = new List + { + "BASE TABLE", "TEMPORARY TABLE", "VIEW" + }; - int results = 0; + int results = 0; - for (int i = 0; i < stringArray.Length; i++) - { - string value = stringArray.GetString(i); + for (int i = 0; i < stringArray.Length; i++) + { + string value = stringArray.GetString(i); - if (known_types.Contains(value)) - { - results++; - } + if (known_types.Contains(value)) + { + results++; } - - Assert.AreEqual(known_types.Count, results); } + + Assert.Equal(known_types.Count, results); } /// /// Validates if the driver can connect to a live server and /// parse the results. /// - [Test, Order(6)] + [SkippableFact, Order(6)] public void CanExecuteQuery() { - if (Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)) - { - SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); + SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - Dictionary parameters = new Dictionary(); - Dictionary options = new Dictionary(); + Dictionary parameters = new Dictionary(); + Dictionary options = new Dictionary(); - AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); + AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - AdbcDatabase adbcDatabase = snowflakeDriver.Open(parameters); - AdbcConnection adbcConnection = adbcDatabase.Connect(options); + AdbcDatabase adbcDatabase = snowflakeDriver.Open(parameters); + AdbcConnection adbcConnection = adbcDatabase.Connect(options); - Console.WriteLine(testConfiguration.Query); + Console.WriteLine(testConfiguration.Query); - AdbcStatement statement = adbcConnection.CreateStatement(); - statement.SqlQuery = testConfiguration.Query; + AdbcStatement statement = adbcConnection.CreateStatement(); + statement.SqlQuery = testConfiguration.Query; - QueryResult queryResult = statement.ExecuteQuery(); + QueryResult queryResult = statement.ExecuteQuery(); - Tests.DriverTests.CanExecuteQuery(queryResult, testConfiguration.ExpectedResultsCount); - } + Tests.DriverTests.CanExecuteQuery(queryResult, testConfiguration.ExpectedResultsCount); } } } diff --git a/csharp/test/Drivers/Snowflake/SampleData.cs b/csharp/test/Drivers/Snowflake/SampleData.cs index 8307f2c10b..6a99fe1ac5 100644 --- a/csharp/test/Drivers/Snowflake/SampleData.cs +++ b/csharp/test/Drivers/Snowflake/SampleData.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; +using System.Data.SqlTypes; using System.Text; using Apache.Arrow.Types; @@ -36,15 +37,15 @@ public static List GetSampleData() List expectedValues = new List() { // https://github.com/apache/arrow-adbc/issues/1020 has Snowflake treat all values as decimal by default - new ColumnNetTypeArrowTypeValue("NUMBERTYPE", typeof(decimal), typeof(Decimal128Type), 1m), - new ColumnNetTypeArrowTypeValue("DECIMALTYPE", typeof(decimal), typeof(Decimal128Type), 1231m), - new ColumnNetTypeArrowTypeValue("NUMERICTYPE", typeof(decimal), typeof(Decimal128Type), 1231m), - new ColumnNetTypeArrowTypeValue("INTTYPE", typeof(decimal), typeof(Decimal128Type), 123m), - new ColumnNetTypeArrowTypeValue("INTEGERTYPE", typeof(decimal), typeof(Decimal128Type), 123m), - new ColumnNetTypeArrowTypeValue("BIGINTTYPE", typeof(decimal), typeof(Decimal128Type), 123m), - new ColumnNetTypeArrowTypeValue("SMALLINTTYPE", typeof(decimal), typeof(Decimal128Type), 123m), - new ColumnNetTypeArrowTypeValue("TINYINTTYPE", typeof(decimal), typeof(Decimal128Type), 123m), - new ColumnNetTypeArrowTypeValue("BYTEINTTYPE", typeof(decimal), typeof(Decimal128Type), 123m), + new ColumnNetTypeArrowTypeValue("NUMBERTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(1m)), + new ColumnNetTypeArrowTypeValue("DECIMALTYPE", typeof(double), typeof(DoubleType), 123.1d), + new ColumnNetTypeArrowTypeValue("NUMERICTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(1231m)), + new ColumnNetTypeArrowTypeValue("INTTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), + new ColumnNetTypeArrowTypeValue("INTEGERTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), + new ColumnNetTypeArrowTypeValue("BIGINTTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), + new ColumnNetTypeArrowTypeValue("SMALLINTTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), + new ColumnNetTypeArrowTypeValue("TINYINTTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), + new ColumnNetTypeArrowTypeValue("BYTEINTTYPE", typeof(SqlDecimal), typeof(Decimal128Type), new SqlDecimal(123m)), new ColumnNetTypeArrowTypeValue("FLOATTYPE", typeof(double), typeof(DoubleType), 123.45d), new ColumnNetTypeArrowTypeValue("FLOAT4TYPE", typeof(double), typeof(DoubleType), 123.45d), new ColumnNetTypeArrowTypeValue("FLOAT8TYPE", typeof(double), typeof(DoubleType), 123.45d),