From b6fdff8062d16533f810bc3a51c8e9d524315cee Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Fri, 3 May 2019 00:47:42 +0200 Subject: [PATCH 01/38] checkpoint commit of low GC Row...don't think its valid though --- pom.xml | 182 +++++++++------- .../java/com/dremio/spark/DefaultSource.java | 2 +- .../dremio/spark/DremioDataSourceReader.java | 36 ---- .../com/dremio/spark/FlightDataReader.java | 200 ++++++++++++++++++ .../dremio/spark/FlightDataReaderFactory.java | 43 ++++ .../dremio/spark/FlightDataSourceReader.java | 124 +++++++++++ .../com/dremio/spark/FlightSparkContext.java | 6 +- .../java/com/dremio/spark/TestConnector.java | 16 +- 8 files changed, 478 insertions(+), 131 deletions(-) delete mode 100644 src/main/java/com/dremio/spark/DremioDataSourceReader.java create mode 100644 src/main/java/com/dremio/spark/FlightDataReader.java create mode 100644 src/main/java/com/dremio/spark/FlightDataReaderFactory.java create mode 100644 src/main/java/com/dremio/spark/FlightDataSourceReader.java diff --git a/pom.xml b/pom.xml index 0a8f136..67d6aa4 100644 --- a/pom.xml +++ b/pom.xml @@ -296,6 +296,16 @@ + + org.codehaus.janino + janino + 3.0.11 + + + + + + org.apache.spark spark-core_2.11 @@ -313,22 +323,30 @@ log4j log4j - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - - org.codehaus.jackson - jackson-core-asl - - - com.fasterxml.jackson.core - jackson-databind - + + + + + + + + + + + + + + + + + + + + + + + + @@ -352,22 +370,22 @@ javax.servlet servlet-api - - org.codehaus.jackson - jackson-mapper-asl - - - org.codehaus.jackson - jackson-core-asl - - - com.fasterxml.jackson.core - jackson-databind - - - org.apache.arrow - arrow-vector - + + + + + + + + + + + + + + + + @@ -376,43 +394,38 @@ ${arrow.version} shaded - - com.dremio.sabot - dremio-sabot-flight - ${dremio.version} - test - - - com.dremio.sabot - dremio-sabot-kernel - ${dremio.version} - - - org.slf4j - slf4j-log4j12 - - - commons-logging - commons-logging - - - log4j - log4j - - - javax.servlet - servlet-api - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + org.slf4j jul-to-slf4j @@ -451,21 +464,26 @@ 4.11 test - - com.dremio.sabot - dremio-sabot-kernel - ${dremio.version} - tests - test - - - com.dremio - dremio-common - ${dremio.version} - tests - test - - + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/java/com/dremio/spark/DefaultSource.java b/src/main/java/com/dremio/spark/DefaultSource.java index 62770b2..2c82006 100644 --- a/src/main/java/com/dremio/spark/DefaultSource.java +++ b/src/main/java/com/dremio/spark/DefaultSource.java @@ -9,6 +9,6 @@ public class DefaultSource implements DataSourceV2, ReadSupport { private final RootAllocator rootAllocator = new RootAllocator(); public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { - return new DremioDataSourceReader(dataSourceOptions, rootAllocator.newChildAllocator(dataSourceOptions.toString(), 0, rootAllocator.getLimit())); + return new FlightDataSourceReader(dataSourceOptions, rootAllocator.newChildAllocator(dataSourceOptions.toString(), 0, rootAllocator.getLimit())); } } diff --git a/src/main/java/com/dremio/spark/DremioDataSourceReader.java b/src/main/java/com/dremio/spark/DremioDataSourceReader.java deleted file mode 100644 index 16b1c75..0000000 --- a/src/main/java/com/dremio/spark/DremioDataSourceReader.java +++ /dev/null @@ -1,36 +0,0 @@ -package com.dremio.spark; - -import org.apache.arrow.flight.FlightClient; -import 
org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.Location; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -import java.util.List; - -public class DremioDataSourceReader implements DataSourceReader { - private DataSourceOptions dataSourceOptions; - - public DremioDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { - this.dataSourceOptions = dataSourceOptions; - FlightClient c = new FlightClient(allocator, new Location(dataSourceOptions.get("host").orElse("localhost"), dataSourceOptions.getInt("port", 43430))); - c.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse("")); - FlightInfo info = c.getInfo(FlightDescriptor.path("sys", "options")); - -// FlightStream s = c.getStream(info.getEndpoints().get(0).getTicket()); - } - - public StructType readSchema() { - return null; - } - - public List> createDataReaderFactories() { - return null; - } -} diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java new file mode 100644 index 0000000..323a390 --- /dev/null +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -0,0 +1,200 @@ +package com.dremio.spark; + +import com.google.common.collect.Maps; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.sources.v2.reader.DataReader; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +import java.io.IOException; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; +import java.util.Map; + +public class FlightDataReader implements DataReader { + private final FlightClient client; + private final FlightStream stream; + private final int schemaSize; + private final BufferAllocator allocator; + private final Schema schema; + private final StructType structType; + private int subCount = 0; + private int currentRow = 0; + + public FlightDataReader( + FlightEndpoint endpoint, + BufferAllocator allocator, + Location defaultLocation, + Schema schema, + StructType structType) { + this.allocator = allocator; + this.schema = schema; + this.structType = structType; + client = new FlightClient(allocator, + (endpoint.getLocations().isEmpty()) ? 
defaultLocation : endpoint.getLocations().get(0)); //todo multiple locations
+    stream = client.getStream(endpoint.getTicket());
+    schemaSize = structType.size();
+
+  }
+
+  @Override
+  public boolean next() throws IOException {
+    if (subCount == currentRow) {
+      boolean hasNext = stream.next();
+      if (!hasNext) {
+        return false;
+      }
+      subCount = stream.getRoot().getRowCount();
+      currentRow = 0;
+    }
+    return true;
+  }
+
+  @Override
+  public Row get() {
+    ArrowRow row = new ArrowRow(structType, schema, stream.getRoot(), currentRow, schemaSize);
+    currentRow++;
+    return row;
+  }
+
+  @Override
+  public void close() throws IOException {
+    allocator.close();
+    try {
+      client.close();
+      stream.close();
+    } catch (Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  public static class ArrowRow implements Row {
+
+    private VectorSchemaRoot root;
+    private final int currentRow;
+    private final int schemaSize;
+    private final StructType schema;
+
+    public ArrowRow(StructType schema,
+                    Schema arrowSchema,
+                    VectorSchemaRoot root,
+                    int currentRow,
+                    int schemaSize) {
+      this.schema = schema;
+      this.currentRow = currentRow;
+      this.schemaSize = schemaSize;
+      this.root = root;
+
+      IntVector iv = (IntVector) root.getVector("c1");
+      int value = iv.get(currentRow);
+    }
+
+    @Override
+    public int size() {
+      return schemaSize;
+    }
+
+    @Override
+    public int length() {
+      return schemaSize;
+    }
+
+    @Override
+    public StructType schema() {
+      return schema;
+    }
+
+    @Override
+    public Object apply(int i) {
+      return get(i);
+    }
+
+    @Override
+    public Object get(int i) {
+      return null;
+    }
+
+    @Override
+    public boolean isNullAt(int i) {
+      return root.getFieldVectors().get(i).isNull(currentRow);
+    }
+
+    @Override
+    public boolean getBoolean(int i) {
+      return ((BitVector) root.getFieldVectors().get(i)).get(currentRow) == 1;
+    }
+
+    @Override
+    public byte getByte(int i) {
+      return super.getByte(i);
+    }
+
+    @Override
+    public short getShort(int i) {
+      return super.getShort(i);
+    }
+
+    @Override
+    public int getInt(int i) {
+      return super.getInt(i);
+    }
+
+    @Override
+    public long getLong(int i) {
+      return super.getLong(i);
+    }
+
+    @Override
+    public float getFloat(int i) {
+      return super.getFloat(i);
+    }
+
+    @Override
+    public double getDouble(int i) {
+      return super.getDouble(i);
+    }
+
+    @Override
+    public String getString(int i) {
+      return super.getString(i);
+    }
+
+    @Override
+    public BigDecimal getDecimal(int i) {
+      return super.getDecimal(i);
+    }
+
+    @Override
+    public Date getDate(int i) {
+      return super.getDate(i);
+    }
+
+    @Override
+    public Timestamp getTimestamp(int i) {
+      return super.getTimestamp(i);
+    }
+
+    @Override
+    public Row copy() {
+      return this;
+    }
+  }
+}
diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java
new file mode 100644
index 0000000..1f1ddef
--- /dev/null
+++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java
@@ -0,0 +1,43 @@
+package com.dremio.spark;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.types.StructType;
+
+public class FlightDataReaderFactory implements DataReaderFactory<Row> {
+
+  private FlightEndpoint endpoint;
+  private final BufferAllocator allocator;
+  private final Location defaultLocation;
+  private Schema schema;
+  private StructType structType;
+
+  public FlightDataReaderFactory(
+      FlightEndpoint endpoint,
+      BufferAllocator allocator,
+      Location defaultLocation,
+      Schema schema,
+      StructType structType) {
+    this.endpoint = endpoint;
+    this.allocator = allocator;
+    this.defaultLocation = defaultLocation;
+    this.schema = schema;
+    this.structType = structType;
+  }
+
+  @Override
+  public String[] preferredLocations() {
+    return endpoint.getLocations().stream().map(Location::getHost).toArray(String[]::new);
+  }
+
+  @Override
+  public DataReader<Row> createDataReader() {
+    return new FlightDataReader(endpoint, allocator, defaultLocation, schema, structType);
+  }
+}
diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java
new file mode 100644
index 0000000..f744f27
--- /dev/null
+++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java
@@ -0,0 +1,124 @@
+package com.dremio.spark;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.sources.v2.DataSourceOptions;
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
+import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class FlightDataSourceReader implements DataSourceReader {
+  private final FlightInfo info;
+  private final Location defaultLocation;
+  private BufferAllocator allocator;
+
+  public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) {
+    defaultLocation = new Location(
+        dataSourceOptions.get("host").orElse("localhost"),
+        dataSourceOptions.getInt("port", 43430)
+    );
+    this.allocator = allocator;
+    FlightClient client = new FlightClient(
+        allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()),
+        defaultLocation);
+    client.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse(""));
+    info = client.getInfo(FlightDescriptor.path("sys", "options"));
+    try {
+      client.close();
+    } catch (InterruptedException e) {
+      e.printStackTrace();
+    }
+  }
+
+  public StructType readSchema() {
+    StructField[] fields = info.getSchema().getFields().stream()
+        .map(field ->
+            new StructField(field.getName(),
+                sparkFromArrow(field.getFieldType()),
+                field.isNullable(),
+                Metadata.empty()))
+        .toArray(StructField[]::new);
+    return new StructType(fields);
+  }
+
+  private DataType sparkFromArrow(FieldType fieldType) {
+    switch (fieldType.getType().getTypeID()) {
+      case Null:
+        return DataTypes.NullType;
+      case Struct:
+        throw new UnsupportedOperationException("have not implemented Struct type yet");
+      case List:
+        throw new UnsupportedOperationException("have not implemented List type yet");
+      case FixedSizeList:
+        throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
+      case Union:
+        throw new UnsupportedOperationException("have not implemented Union type yet");
+      case Int:
+        ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
+        int bitWidth = intType.getBitWidth();
+        if (bitWidth == 8) {
+          return DataTypes.ByteType;
+        } else if (bitWidth == 16) {
+          return DataTypes.ShortType;
+        } else if (bitWidth == 32) {
+          return DataTypes.IntegerType;
+        } else if (bitWidth == 64) {
+          return DataTypes.LongType;
+        }
+        throw new UnsupportedOperationException("unknown int type with bit width " + bitWidth);
+      case FloatingPoint:
+        ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType();
+        FloatingPointPrecision precision = floatType.getPrecision();
+        switch (precision) {
+          case HALF:
+          case SINGLE:
+            return DataTypes.FloatType;
+          case DOUBLE:
+            return DataTypes.DoubleType;
+        }
+      case Utf8:
+        return DataTypes.StringType;
+      case Binary:
+//      case FixedSizeBinary:
+        return DataTypes.BinaryType;
+      case Bool:
+        return DataTypes.BooleanType;
+      case Decimal:
+        throw new UnsupportedOperationException("have not implemented Decimal type yet");
+      case Date:
+        return DataTypes.DateType;
+      case Time:
+        return DataTypes.TimestampType; // note: I don't know what this will do!
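+        // Arrow's Time type encodes a time-of-day offset from midnight, while Spark's
+        // TimestampType expects an absolute instant (microseconds since the Unix epoch),
+        // so this mapping is best treated as a placeholder until a dedicated conversion exists.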
+      case Timestamp:
+        return DataTypes.TimestampType;
+      case Interval:
+        return DataTypes.CalendarIntervalType;
+      case NONE:
+        return DataTypes.NullType;
+    }
+    throw new IllegalStateException("Unexpected value: " + fieldType);
+  }
+
+  public List<DataReaderFactory<Row>> createDataReaderFactories() {
+    return info.getEndpoints().stream().map(endpoint ->
+        new FlightDataReaderFactory(endpoint,
+            allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()),
+            defaultLocation,
+            info.getSchema(),
+            readSchema())).collect(Collectors.toList());
+  }
+}
diff --git a/src/main/java/com/dremio/spark/FlightSparkContext.java b/src/main/java/com/dremio/spark/FlightSparkContext.java
index c51e094..38b24f3 100644
--- a/src/main/java/com/dremio/spark/FlightSparkContext.java
+++ b/src/main/java/com/dremio/spark/FlightSparkContext.java
@@ -4,6 +4,8 @@
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.DataFrameReader;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 
 public class FlightSparkContext {
@@ -22,8 +24,8 @@ public static FlightSparkContext flightContext(JavaSparkContext sc) {
     return new FlightSparkContext(sc.sc(), sc.getConf());
   }
 
-  public void read(String s) {
-    reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port")))
+  public Dataset<Row> read(String s) {
+    return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port")))
         .option("host", conf.get("spark.flight.endpoint.host"))
         .option("username", conf.get("spark.flight.username"))
         .option("password", conf.get("spark.flight.password"))
diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java
index 5889986..3083cf5 100644
--- a/src/test/java/com/dremio/spark/TestConnector.java
+++ b/src/test/java/com/dremio/spark/TestConnector.java
@@ -1,27 +1,19 @@
 package com.dremio.spark;
 
-import com.dremio.BaseTestQuery;
-import com.dremio.exec.ExecTest;
-import com.dremio.service.InitializerRegistry;
 import org.apache.spark.SparkConf;
-import org.apache.spark.sql.SQLContext;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 
-public class TestConnector extends BaseTestQuery {
+public class TestConnector {
   private static SparkConf conf;
   private static JavaSparkContext sc;
   private static FlightSparkContext csc;
-  private static InitializerRegistry registry;
 
   @BeforeClass
   public static void setUp() throws Exception {
-    registry = new InitializerRegistry(ExecTest.CLASSPATH_SCAN_RESULT, getBindingProvider());
-    registry.start();
     conf = new SparkConf()
         .setAppName("flightTest")
         .setMaster("local[*]")
@@ -37,7 +29,6 @@ public static void setUp() throws Exception {
 
   @AfterClass
   public static void tearDown() throws Exception {
-    registry.close();
     sc.close();
   }
 
@@ -45,4 +36,9 @@
   public void testConnect() {
     csc.read("sys.options");
   }
+
+  @Test
+  public void testRead() {
+    csc.read("sys.options").show();
+  }
 }
From d38c007697e31fc04137436ede9b247148703fac Mon Sep 17 00:00:00 2001
From: Ryan Murray
Date: Fri, 3 May 2019 14:47:32 +0100
Subject: [PATCH 02/38] working spark on 2.3 but not super fast
---
 pom.xml                                         |  22 ++-
 .../com/dremio/spark/FlightDataReader.java      | 174 ++++--------
 .../dremio/spark/FlightDataReaderFactory.java   |  32 ++--
 .../dremio/spark/FlightDataSourceReader.java    |  19 +-
.../java/com/dremio/spark/TestConnector.java | 28 ++- 5 files changed, 106 insertions(+), 169 deletions(-) diff --git a/pom.xml b/pom.xml index 67d6aa4..e67c4bd 100644 --- a/pom.xml +++ b/pom.xml @@ -347,6 +347,10 @@ + + org.apache.arrow + arrow-vector + @@ -382,10 +386,10 @@ - - - - + + org.apache.arrow + arrow-vector + @@ -394,6 +398,16 @@ ${arrow.version} shaded + + org.apache.arrow + arrow-vector + ${arrow.version} + + + com.dremio.client + dremio-client-jdbc + ${dremio.version} + diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index 323a390..0c99de9 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -1,52 +1,46 @@ package com.dremio.spark; -import com.google.common.collect.Maps; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; import org.apache.spark.sql.Row; -import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import java.io.IOException; -import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.List; -import java.util.Map; public class FlightDataReader implements DataReader { private final FlightClient client; private final FlightStream stream; - private final int schemaSize; +// private final int schemaSize; private final BufferAllocator allocator; - private final Schema schema; +// private final Schema schema; private final StructType structType; private int subCount = 0; private int currentRow = 0; public FlightDataReader( - FlightEndpoint endpoint, + byte[] ticket, BufferAllocator allocator, - Location defaultLocation, - Schema schema, + String defaultHost, + int defaultPort, +// Schema schema, StructType structType) { - this.allocator = allocator; - this.schema = schema; + this.allocator = new RootAllocator(); +// this.schema = schema; this.structType = structType; - client = new FlightClient(allocator, - (endpoint.getLocations().isEmpty()) ? 
defaultLocation : endpoint.getLocations().get(0)); //todo multiple locations
-    stream = client.getStream(endpoint.getTicket());
-    schemaSize = structType.size();
+    client = new FlightClient(this.allocator, new Location(defaultHost, defaultPort)); //todo multiple locations
+    client.authenticateBasic("dremio", "dremio123");
+    stream = client.getStream(new Ticket(ticket));
+//    schemaSize = structType.size();
   }
@@ -65,136 +59,30 @@ public boolean next() throws IOException {
 
   @Override
   public Row get() {
-    ArrowRow row = new ArrowRow(structType, schema, stream.getRoot(), currentRow, schemaSize);
+    Row row = new GenericRowWithSchema(stream.getRoot().getFieldVectors()
+        .stream()
+        .map(v -> {
+          Object o = v.getObject(currentRow);
+          if (o instanceof Text) {
+            return o.toString();
+          }
+          return o;
+        })
+        .toArray(Object[]::new),
+        structType);
     currentRow++;
     return row;
   }
 
   @Override
   public void close() throws IOException {
-    allocator.close();
+
     try {
-      client.close();
-      stream.close();
+//      client.close();
+//      stream.close();
+//      allocator.close();
     } catch (Exception e) {
       throw new IOException(e);
     }
   }
-
-  public static class ArrowRow implements Row {
-
-    private VectorSchemaRoot root;
-    private final int currentRow;
-    private final int schemaSize;
-    private final StructType schema;
-
-    public ArrowRow(StructType schema,
-                    Schema arrowSchema,
-                    VectorSchemaRoot root,
-                    int currentRow,
-                    int schemaSize) {
-      this.schema = schema;
-      this.currentRow = currentRow;
-      this.schemaSize = schemaSize;
-      this.root = root;
-
-      IntVector iv = (IntVector) root.getVector("c1");
-      int value = iv.get(currentRow);
-    }
-
-    @Override
-    public int size() {
-      return schemaSize;
-    }
-
-    @Override
-    public int length() {
-      return schemaSize;
-    }
-
-    @Override
-    public StructType schema() {
-      return schema;
-    }
-
-    @Override
-    public Object apply(int i) {
-      return get(i);
-    }
-
-    @Override
-    public Object get(int i) {
-      return null;
-    }
-
-    @Override
-    public boolean isNullAt(int i) {
-      return root.getFieldVectors().get(i).isNull(currentRow);
-    }
-
-    @Override
-    public boolean getBoolean(int i) {
-      return ((BitVector) root.getFieldVectors().get(i)).get(currentRow) == 1;
-    }
-
-    @Override
-    public byte getByte(int i) {
-      return super.getByte(i);
-    }
-
-    @Override
-    public short getShort(int i) {
-      return super.getShort(i);
-    }
-
-    @Override
-    public int getInt(int i) {
-      return super.getInt(i);
-    }
-
-    @Override
-    public long getLong(int i) {
-      return super.getLong(i);
-    }
-
-    @Override
-    public float getFloat(int i) {
-      return super.getFloat(i);
-    }
-
-    @Override
-    public double getDouble(int i) {
-      return super.getDouble(i);
-    }
-
-    @Override
-    public String getString(int i) {
-      return super.getString(i);
-    }
-
-    @Override
-    public BigDecimal getDecimal(int i) {
-      return super.getDecimal(i);
-    }
-
-    @Override
-    public Date getDate(int i) {
-      return super.getDate(i);
-    }
-
-    @Override
-    public Timestamp getTimestamp(int i) {
-      return super.getTimestamp(i);
-    }
-
-    @Override
-    public Row copy() {
-      return this;
-    }
-  }
 }
diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java
index 1f1ddef..a8c9fe4 100644
--- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java
+++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java
@@ -3,6 +3,7 @@
 import org.apache.arrow.flight.FlightClient;
 import org.apache.arrow.flight.FlightEndpoint;
 import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.Ticket;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.spark.sql.Row;
@@ -12,32 +13,35 @@
 
 public class FlightDataReaderFactory implements DataReaderFactory<Row> {
 
-  private FlightEndpoint endpoint;
-  private final BufferAllocator allocator;
-  private final Location defaultLocation;
-  private Schema schema;
+  private byte[] ticket;
+//  private final BufferAllocator allocator;
+  private final String defaultHost;
+  private final int defaultPort;
+//  private Schema schema;
   private StructType structType;
 
   public FlightDataReaderFactory(
-      FlightEndpoint endpoint,
-      BufferAllocator allocator,
-      Location defaultLocation,
-      Schema schema,
+      byte[] ticket,
+//      BufferAllocator allocator,
+      String defaultHost,
+      int defaultPort,
+//      Schema schema,
      StructType structType) {
-    this.endpoint = endpoint;
-    this.allocator = allocator;
-    this.defaultLocation = defaultLocation;
-    this.schema = schema;
+    this.ticket = ticket;
+//    this.allocator = allocator;
+    this.defaultHost = defaultHost;
+    this.defaultPort = defaultPort;
+//    this.schema = schema;
     this.structType = structType;
   }
 
   @Override
   public String[] preferredLocations() {
-    return endpoint.getLocations().stream().map(Location::getHost).toArray(String[]::new);
+    return new String[0]; //endpoint.getLocations().stream().map(Location::getHost).toArray(String[]::new);
   }
 
   @Override
   public DataReader<Row> createDataReader() {
-    return new FlightDataReader(endpoint, allocator, defaultLocation, schema, structType);
+    return new FlightDataReader(ticket, null, defaultHost, defaultPort, structType);
   }
 }
diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java
index f744f27..6149d1c 100644
--- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java
+++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java
@@ -36,7 +36,7 @@ public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocat
         allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()),
         defaultLocation);
     client.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"),
dataSourceOptions.get("password").orElse("")); - info = client.getInfo(FlightDescriptor.path("sys", "options")); + info = client.getInfo(FlightDescriptor.path(dataSourceOptions.get("path").orElse("").split("\\."))); try { client.close(); } catch (InterruptedException e) { @@ -114,11 +114,16 @@ private DataType sparkFromArrow(FieldType fieldType) { } public List> createDataReaderFactories() { - return info.getEndpoints().stream().map(endpoint -> - new FlightDataReaderFactory(endpoint, - allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()), - defaultLocation, - info.getSchema(), - readSchema())).collect(Collectors.toList()); + + return info.getEndpoints().stream().map(endpoint -> { + Location location = (endpoint.getLocations().isEmpty()) ? + new Location(defaultLocation.getHost(), defaultLocation.getPort()) : + endpoint.getLocations().get(0); + return new FlightDataReaderFactory(endpoint.getTicket().getBytes(), +// allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()), + location.getHost(), location.getPort(), + //info.getSchema(), + readSchema()); + }).collect(Collectors.toList()); } } diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index 3083cf5..be0debc 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -1,12 +1,17 @@ package com.dremio.spark; import org.apache.spark.SparkConf; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; +import java.util.Properties; + public class TestConnector { private static SparkConf conf; private static JavaSparkContext sc; @@ -24,6 +29,7 @@ public static void setUp() throws Exception { .set("spark.flight.password", "dremio123") ; sc = new JavaSparkContext(conf); + csc = FlightSparkContext.flightContext(sc); } @@ -39,6 +45,26 @@ public void testConnect() { @Test public void testRead() { - csc.read("sys.options").show(); + long[] jdbcT = new long[16]; + long[] flightT = new long[16]; + Properties connectionProperties = new Properties(); + connectionProperties.put("user", "dremio"); + connectionProperties.put("password", "dremio123"); + long jdbcC = 0; + long flightC = 0; + for (int i=0;i<2;i++) { + long now = System.currentTimeMillis(); + Dataset jdbc = SQLContext.getOrCreate(sc.sc()).read().jdbc("jdbc:dremio:direct=localhost:31010", "\"@dremio\".sdd", connectionProperties); + jdbcC = jdbc.count(); + long then = System.currentTimeMillis(); + flightC = csc.read("@dremio.sdd").count(); + long andHereWeAre = System.currentTimeMillis(); + jdbcT[i] = then-now; + flightT[i] = andHereWeAre - then; + } + for (int i =0;i<16;i++) { + System.out.println("Trial " + i + ": Flight took " + flightT[i] + " and jdbc took " + jdbcT[i]); + } + System.out.println("Fetched " + jdbcC + " row from jdbc and " + flightC + " from flight"); } } From e67bb6d92400c69b90e6c2d774099d5dbdd53223 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Sun, 5 May 2019 11:51:29 +0100 Subject: [PATCH 03/38] moved to spark 2.4, started to add filter --- pom.xml | 140 ++---------------- .../com/dremio/spark/FlightDataReader.java | 68 +++------ .../dremio/spark/FlightDataReaderFactory.java | 31 +--- .../dremio/spark/FlightDataSourceReader.java | 107 ++++++++++--- .../com/dremio/spark/FlightSparkContext.java | 17 ++- 
.../java/com/dremio/spark/TestConnector.java | 37 ++++- 6 files changed, 172 insertions(+), 228 deletions(-) diff --git a/pom.xml b/pom.xml index e67c4bd..a8bbaa9 100644 --- a/pom.xml +++ b/pom.xml @@ -11,7 +11,7 @@ 3.1.10-201904162146020182-adf690d 0.14.0-SNAPSHOT - 2.3.3 + 2.4.2 1.7.25 @@ -275,36 +275,13 @@ 8 - - - - - - - - - - - - - - - - - - - - org.codehaus.janino - janino - 3.0.11 - - - - + + + org.apache.spark @@ -323,34 +300,6 @@ log4j log4j - - - - - - - - - - - - - - - - - - - - - - - - - - org.apache.arrow - arrow-vector - @@ -374,22 +323,6 @@ javax.servlet servlet-api - - - - - - - - - - - - - - org.apache.arrow - arrow-vector - @@ -398,46 +331,15 @@ ${arrow.version} shaded - - org.apache.arrow - arrow-vector - ${arrow.version} - - - com.dremio.client - dremio-client-jdbc - ${dremio.version} - - - + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - @@ -478,26 +380,6 @@ 4.11 test - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index 0c99de9..6b96a4d 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -1,88 +1,60 @@ package com.dremio.spark; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.Text; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import java.io.IOException; -public class FlightDataReader implements DataReader { +public class FlightDataReader implements InputPartitionReader { private final FlightClient client; private final FlightStream stream; -// private final int schemaSize; private final BufferAllocator allocator; -// private final Schema schema; - private final StructType structType; - private int subCount = 0; - private int currentRow = 0; public FlightDataReader( byte[] ticket, - BufferAllocator allocator, String defaultHost, - int defaultPort, -// Schema schema, - StructType structType) { + int defaultPort) { this.allocator = new RootAllocator(); -// this.schema = schema; - this.structType = structType; client = new FlightClient(this.allocator, new Location(defaultHost, defaultPort)); //todo multiple locations client.authenticateBasic("dremio", "dremio123"); stream = client.getStream(new Ticket(ticket)); -// schemaSize = structType.size(); - } @Override public boolean next() throws IOException { - if (subCount == currentRow) { - boolean hasNext = stream.next(); - if (!hasNext) { - return false; - } - subCount = stream.getRoot().getRowCount(); - currentRow = 0; - } - return true; + return stream.next(); } @Override - public Row get() { - Row row = new GenericRowWithSchema(stream.getRoot().getFieldVectors() - .stream() - .map(v -> { - Object o = v.getObject(currentRow); - if (o instanceof Text) { 
- return o.toString(); - } - return o; - }) - .toArray(Object[]::new), - structType); - currentRow++; - return row; + public ColumnarBatch get() { + ColumnarBatch batch = new ColumnarBatch( + stream.getRoot().getFieldVectors() + .stream() + .map(ArrowColumnVector::new) + .toArray(ColumnVector[]::new) + ); + batch.setNumRows(stream.getRoot().getRowCount()); + return batch; } @Override public void close() throws IOException { - try { +// try { // client.close(); // stream.close(); // allocator.close(); - } catch (Exception e) { - throw new IOException(e); - } +// } catch (Exception e) { +// throw new IOException(e); +// } } } diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java index a8c9fe4..dfda33e 100644 --- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java +++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java @@ -1,38 +1,22 @@ package com.dremio.spark; -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.sources.v2.reader.DataReader; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; -public class FlightDataReaderFactory implements DataReaderFactory { +public class FlightDataReaderFactory implements InputPartition { private byte[] ticket; -// private final BufferAllocator allocator; private final String defaultHost; private final int defaultPort; -// private Schema schema; - private StructType structType; public FlightDataReaderFactory( byte[] ticket, -// BufferAllocator allocator, String defaultHost, - int defaultPort, -// Schema schema, - StructType structType) { + int defaultPort) { this.ticket = ticket; -// this.allocator = allocator; this.defaultHost = defaultHost; this.defaultPort = defaultPort; -// this.schema = schema; - this.structType = structType; } @Override @@ -41,7 +25,8 @@ public String[] preferredLocations() { } @Override - public DataReader createDataReader() { - return new FlightDataReader(ticket, null, defaultHost, defaultPort, structType); + public InputPartitionReader createPartitionReader() { + return new FlightDataReader(ticket, defaultHost, defaultPort); } + } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index 6149d1c..9cfd83e 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -1,5 +1,6 @@ package com.dremio.spark; +import com.clearspring.analytics.util.Lists; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; @@ -8,35 +9,36 @@ import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.IsNotNull; import 
org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.reader.DataReaderFactory; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; +import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; import java.util.List; import java.util.stream.Collectors; -public class FlightDataSourceReader implements DataSourceReader { +public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters { private final FlightInfo info; private final Location defaultLocation; - private BufferAllocator allocator; + private Filter[] pushed; public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { defaultLocation = new Location( dataSourceOptions.get("host").orElse("localhost"), - dataSourceOptions.getInt("port", 43430) + dataSourceOptions.getInt("port", 47470) ); - this.allocator = allocator; - FlightClient client = new FlightClient( - allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()), - defaultLocation); - client.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse("")); - info = client.getInfo(FlightDescriptor.path(dataSourceOptions.get("path").orElse("").split("\\."))); + FlightClient client = new FlightClient(allocator,defaultLocation); + client.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse(null)); + info = client.getInfo(getDescriptor(dataSourceOptions)); try { client.close(); } catch (InterruptedException e) { @@ -44,6 +46,43 @@ public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocat } } + private FlightDescriptor getDescriptor(DataSourceOptions dataSourceOptions) { + if (dataSourceOptions.getBoolean("isSql", false)) { + return FlightDescriptor.command(dataSourceOptions.get("path").orElse("").getBytes()); + } + String path = dataSourceOptions.get("path").orElse(""); + List paths = Lists.newArrayList(); + StringBuilder current = new StringBuilder(); + boolean isQuote = false; + for (char c: path.toCharArray()) { + if (isQuote && c != '"') { + current.append(c); + } else if (isQuote) { + if (current.length() > 0) { + paths.add(current.toString()); + } + current = new StringBuilder(); + isQuote = false; + } else if (c == '"') { + if (current.length() > 0) { + paths.add(current.toString()); + } + current = new StringBuilder(); + isQuote = true; + } else if (c == '.'){ + if (current.length() > 0) { + paths.add(current.toString()); + } + current = new StringBuilder(); + isQuote = false; + } else { + current.append(c); + } + } + paths.add(current.toString()); + return FlightDescriptor.path(paths); + } + public StructType readSchema() { StructField[] fields = info.getSchema().getFields().stream() .map(field -> @@ -93,7 +132,7 @@ private DataType sparkFromArrow(FieldType fieldType) { case Utf8: return DataTypes.StringType; case Binary: -// case FixedSizeBinary: + case FixedSizeBinary: return DataTypes.BinaryType; case Bool: return DataTypes.BooleanType; @@ -113,17 +152,45 @@ private DataType 
sparkFromArrow(FieldType fieldType) { throw new IllegalStateException("Unexpected value: " + fieldType); } - public List> createDataReaderFactories() { - + @Override + public List> planBatchInputPartitions() { return info.getEndpoints().stream().map(endpoint -> { Location location = (endpoint.getLocations().isEmpty()) ? new Location(defaultLocation.getHost(), defaultLocation.getPort()) : endpoint.getLocations().get(0); - return new FlightDataReaderFactory(endpoint.getTicket().getBytes(), -// allocator.newChildAllocator("data-source-reader", 0, allocator.getLimit()), - location.getHost(), location.getPort(), - //info.getSchema(), - readSchema()); + return new FlightDataReaderFactory(endpoint.getTicket().getBytes(), + location.getHost(), + location.getPort()); }).collect(Collectors.toList()); } + + @Override + public Filter[] pushFilters(Filter[] filters) { + List notPushed = Lists.newArrayList(); + List pushed = Lists.newArrayList(); + for (Filter filter: filters) { + boolean isPushed = canBePushed(filter); + if (isPushed) { + pushed.add(filter); + } else { + notPushed.add(filter); + } + } + this.pushed = pushed.toArray(new Filter[0]); + return notPushed.toArray(new Filter[0]); + } + + private boolean canBePushed(Filter filter) { + if (filter instanceof IsNotNull) { + return true; + } else if (filter instanceof EqualTo){ + return true; + } + return false; + } + + @Override + public Filter[] pushedFilters() { + return pushed; + } } diff --git a/src/main/java/com/dremio/spark/FlightSparkContext.java b/src/main/java/com/dremio/spark/FlightSparkContext.java index 38b24f3..8d5caf6 100644 --- a/src/main/java/com/dremio/spark/FlightSparkContext.java +++ b/src/main/java/com/dremio/spark/FlightSparkContext.java @@ -10,12 +10,11 @@ public class FlightSparkContext { - private final SQLContext sqlContext; private SparkConf conf; private final DataFrameReader reader; private FlightSparkContext(SparkContext sc, SparkConf conf) { - sqlContext = SQLContext.getOrCreate(sc); + SQLContext sqlContext = SQLContext.getOrCreate(sc); this.conf = conf; reader = sqlContext.read().format("com.dremio.spark"); } @@ -25,10 +24,20 @@ public static FlightSparkContext flightContext(JavaSparkContext sc) { } public Dataset read(String s) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("isSql", false) + .load(s); + } + + public Dataset readSql(String s) { return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.username")) - .option("password", conf.get("spark.flight.password")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("isSql", true) .load(s); } } diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index be0debc..13b170c 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -5,7 +5,9 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.BeforeClass; +import org.junit.Ignore; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; @@ 
-22,11 +24,11 @@ public static void setUp() throws Exception { conf = new SparkConf() .setAppName("flightTest") .setMaster("local[*]") -// .set("spark.driver.allowMultipleContexts","true") + .set("spark.driver.allowMultipleContexts","true") .set("spark.flight.endpoint.host", "localhost") .set("spark.flight.endpoint.port", "47470") - .set("spark.flight.username", "dremio") - .set("spark.flight.password", "dremio123") + .set("spark.flight.auth.username", "dremio") + .set("spark.flight.auth.password", "dremio123") ; sc = new JavaSparkContext(conf); @@ -45,6 +47,33 @@ public void testConnect() { @Test public void testRead() { + long count = csc.read("sys.options").count(); + Assert.assertTrue(count > 0); + } + + @Test + public void testReadWithQuotes() { + long count = csc.read("\"sys\".options").count(); + Assert.assertTrue(count > 0); + } + + @Test + public void testSql() { + long count = csc.readSql("select * from \"sys\".options").count(); + Assert.assertTrue(count > 0); + } + + @Test + public void testFilter() { + Dataset df = csc.readSql("select * from \"sys\".options"); + long count = df.filter(df.col("kind").equalTo("LONG")).count(); + long countOriginal = csc.readSql("select * from \"sys\".options").count(); + Assert.assertTrue(count < countOriginal); + } + + @Ignore + @Test + public void testSpeed() { long[] jdbcT = new long[16]; long[] flightT = new long[16]; Properties connectionProperties = new Properties(); @@ -52,7 +81,7 @@ public void testRead() { connectionProperties.put("password", "dremio123"); long jdbcC = 0; long flightC = 0; - for (int i=0;i<2;i++) { + for (int i=0;i<4;i++) { long now = System.currentTimeMillis(); Dataset jdbc = SQLContext.getOrCreate(sc.sc()).read().jdbc("jdbc:dremio:direct=localhost:31010", "\"@dremio\".sdd", connectionProperties); jdbcC = jdbc.count(); From 78b114ffabc11619ba08bcc37e584c0f6951ce9a Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Thu, 9 May 2019 15:49:50 +0100 Subject: [PATCH 04/38] mostly add where clause --- pom.xml | 24 ++ .../dremio/proto/flight/commands/Command.java | 216 ++++++++++++++++++ .../com/dremio/spark/FlightClientFactory.java | 27 +++ .../dremio/spark/FlightDataSourceReader.java | 109 +++++---- 4 files changed, 334 insertions(+), 42 deletions(-) create mode 100644 src/main/java/com/dremio/proto/flight/commands/Command.java create mode 100644 src/main/java/com/dremio/spark/FlightClientFactory.java diff --git a/pom.xml b/pom.xml index a8bbaa9..3e6241f 100644 --- a/pom.xml +++ b/pom.xml @@ -13,6 +13,7 @@ 0.14.0-SNAPSHOT 2.4.2 1.7.25 + 1.4.4 @@ -331,6 +332,29 @@ ${arrow.version} shaded + + io.protostuff + protostuff-core + ${protostuff.version} + + + + io.protostuff + protostuff-collectionschema + ${protostuff.version} + + + + io.protostuff + protostuff-runtime + ${protostuff.version} + + + + io.protostuff + protostuff-api + ${protostuff.version} + diff --git a/src/main/java/com/dremio/proto/flight/commands/Command.java b/src/main/java/com/dremio/proto/flight/commands/Command.java new file mode 100644 index 0000000..cf3af29 --- /dev/null +++ b/src/main/java/com/dremio/proto/flight/commands/Command.java @@ -0,0 +1,216 @@ +// Generated by http://code.google.com/p/protostuff/ ... DO NOT EDIT! 
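+// Command is the message serialized into FlightDescriptor.command by FlightDataSourceReader
+// (via ProtostuffIOUtil.toByteArray): field 1 carries the SQL query text and field 2 the flag,
+// wired through the "parallel" option, that asks the server to plan a parallel read.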
+// Generated from protobuf + +package com.dremio.proto.flight.commands; + +import javax.annotation.Generated; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Objects; + +import io.protostuff.GraphIOUtil; +import io.protostuff.Input; +import io.protostuff.Message; +import io.protostuff.Output; +import io.protostuff.Schema; + +import io.protostuff.UninitializedMessageException; +@Generated("dremio_java_bean.java.stg") +public final class Command implements Externalizable, Message, Schema +{ + + public static Schema getSchema() + { + return DEFAULT_INSTANCE; + } + + public static Command getDefaultInstance() + { + return DEFAULT_INSTANCE; + } + + static final Command DEFAULT_INSTANCE = new Command(); + + + private String + query; + private Boolean + parallel; + + public Command() + { + + } + + public Command( + String query, + Boolean parallel + ) + { + this.query = query; + this.parallel = parallel; + } + + // getters and setters + + // query + public String + getQuery() + { + return query; + } + + public Command setQuery(String + query) + { + this.query = query; + return this; + } + + // parallel + public Boolean + getParallel() + { + return parallel; + } + + public Command setParallel(Boolean + parallel) + { + this.parallel = parallel; + return this; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || this.getClass() != obj.getClass()) { + return false; + } + final Command that = (Command) obj; + return + Objects.equals(this.query, that.query) && + Objects.equals(this.parallel, that.parallel); + } + + @Override + public int hashCode() { + return Objects.hash(query, parallel); + } + + @Override + public String toString() { + return "Command{" + + "query=" + query + + ", parallel=" + parallel + + '}'; + } + // java serialization + + public void readExternal(ObjectInput in) throws IOException + { + GraphIOUtil.mergeDelimitedFrom(in, this, this); + } + + public void writeExternal(ObjectOutput out) throws IOException + { + GraphIOUtil.writeDelimitedTo(out, this, this); + } + + // message method + + public Schema cachedSchema() + { + return DEFAULT_INSTANCE; + } + + // schema methods + + public Command newMessage() + { + return new Command(); + } + + public Class typeClass() + { + return Command.class; + } + + public String messageName() + { + return Command.class.getSimpleName(); + } + + public String messageFullName() + { + return Command.class.getName(); + } + + public boolean isInitialized(Command message) + { + return + message.query != null + && message.parallel != null; + } + + public void mergeFrom(Input input, Command message) throws IOException + { + for(int number = input.readFieldNumber(this);; number = input.readFieldNumber(this)) + { + switch(number) + { + case 0: + return; + case 1: + message.query = input.readString(); + break; + case 2: + message.parallel = input.readBool(); + break; + default: + input.handleUnknownField(number, this); + } + } + } + + + public void writeTo(Output output, Command message) throws IOException + { + if(message.query == null) + throw new UninitializedMessageException(message); + output.writeString(1, message.query, false); + + if(message.parallel == null) + throw new UninitializedMessageException(message); + output.writeBool(2, message.parallel, false); + } + + public String getFieldName(int number) + { + switch(number) + { + case 1: return "query"; + case 2: return "parallel"; + default: return 
null; + } + } + + public int getFieldNumber(String name) + { + final Integer number = __fieldMap.get(name); + return number == null ? 0 : number.intValue(); + } + + private static final java.util.HashMap __fieldMap = new java.util.HashMap(); + static + { + __fieldMap.put("query", 1); + __fieldMap.put("parallel", 2); + } + + +} diff --git a/src/main/java/com/dremio/spark/FlightClientFactory.java b/src/main/java/com/dremio/spark/FlightClientFactory.java new file mode 100644 index 0000000..f6c778f --- /dev/null +++ b/src/main/java/com/dremio/spark/FlightClientFactory.java @@ -0,0 +1,27 @@ +package com.dremio.spark; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; + +public class FlightClientFactory { + private BufferAllocator allocator; + private Location defaultLocation; + private final String username; + private final String password; + + public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password) { + this.allocator = allocator; + this.defaultLocation = defaultLocation; + this.username = username; + this.password = password; + } + + public FlightClient apply() { + FlightClient client = new FlightClient(allocator, defaultLocation); + client.authenticateBasic(username, password); + return client; + + } + +} diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index 9cfd83e..7f5e267 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -1,6 +1,10 @@ package com.dremio.spark; import com.clearspring.analytics.util.Lists; +import com.dremio.proto.flight.commands.Command; +import com.google.common.base.Joiner; +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; @@ -22,13 +26,22 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.List; import java.util.stream.Collectors; public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters { - private final FlightInfo info; + private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); + private static final Joiner JOINER = Joiner.on(" and "); + private FlightInfo info; + private FlightDescriptor descriptor; + private final LinkedBuffer buffer = LinkedBuffer.allocate(); private final Location defaultLocation; + private final FlightClientFactory clientFactory; + private final boolean parallel; + private String sql; private Filter[] pushed; public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { @@ -36,51 +49,26 @@ public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocat dataSourceOptions.get("host").orElse("localhost"), dataSourceOptions.getInt("port", 47470) ); - FlightClient client = new FlightClient(allocator,defaultLocation); - client.authenticateBasic(dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse(null)); - info = client.getInfo(getDescriptor(dataSourceOptions)); - try { - client.close(); + clientFactory = new FlightClientFactory(allocator, + 
defaultLocation, + dataSourceOptions.get("username").orElse("anonymous"), + dataSourceOptions.get("password").orElse(null) + ); + parallel = dataSourceOptions.getBoolean("parallel", false); + sql = dataSourceOptions.get("path").orElse(""); + descriptor = getDescriptor(dataSourceOptions.getBoolean("isSql", false), sql); + try (FlightClient client = clientFactory.apply()) { + info = client.getInfo(descriptor); } catch (InterruptedException e) { - e.printStackTrace(); + throw new RuntimeException(e); } } - private FlightDescriptor getDescriptor(DataSourceOptions dataSourceOptions) { - if (dataSourceOptions.getBoolean("isSql", false)) { - return FlightDescriptor.command(dataSourceOptions.get("path").orElse("").getBytes()); - } - String path = dataSourceOptions.get("path").orElse(""); - List paths = Lists.newArrayList(); - StringBuilder current = new StringBuilder(); - boolean isQuote = false; - for (char c: path.toCharArray()) { - if (isQuote && c != '"') { - current.append(c); - } else if (isQuote) { - if (current.length() > 0) { - paths.add(current.toString()); - } - current = new StringBuilder(); - isQuote = false; - } else if (c == '"') { - if (current.length() > 0) { - paths.add(current.toString()); - } - current = new StringBuilder(); - isQuote = true; - } else if (c == '.'){ - if (current.length() > 0) { - paths.add(current.toString()); - } - current = new StringBuilder(); - isQuote = false; - } else { - current.append(c); - } - } - paths.add(current.toString()); - return FlightDescriptor.path(paths); + private FlightDescriptor getDescriptor(boolean isSql, String path) { + String query = (!isSql) ? ("select * from " + path) : path; + byte[] message = ProtostuffIOUtil.toByteArray(new Command(query , parallel), Command.getSchema(), buffer); + buffer.clear(); + return FlightDescriptor.command(message); } public StructType readSchema() { @@ -177,15 +165,52 @@ public Filter[] pushFilters(Filter[] filters) { } } this.pushed = pushed.toArray(new Filter[0]); + if (!pushed.isEmpty()) { + String whereClause = generateWhereClause(pushed); + mergeWhereDescriptors(whereClause); + try (FlightClient client = clientFactory.apply()) { + info = client.getInfo(descriptor); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } return notPushed.toArray(new Filter[0]); } + private void mergeWhereDescriptors(String whereClause) { + if (sql.contains(" where ")) { + throw new UnsupportedOperationException("have not yet done the regex to insert where clauses"); + } + sql += " where " + whereClause; + descriptor = getDescriptor(true, sql); + } + + private String generateWhereClause(List pushed) { + List filterStr = Lists.newArrayList(); + for (Filter filter: pushed) { + if (filter instanceof IsNotNull) { + filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute())); + } else if (filter instanceof EqualTo){ + filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value()))); + } + } + return JOINER.join(filterStr); + } + + private String valueToString(Object value) { + if (value instanceof String) { + return String.format("'%s'", value); + } + return value.toString(); + } + private boolean canBePushed(Filter filter) { if (filter instanceof IsNotNull) { return true; } else if (filter instanceof EqualTo){ return true; } + LOGGER.error("Cant push filter of type " + filter.toString()); return false; } From 01701e023b5aacf5647cf21bcc19356f0d6bbf2d Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Wed, 22 May 2019 
17:26:22 +0100 Subject: [PATCH 05/38] start parallel read --- .../dremio/spark/FlightDataReaderFactory.java | 2 +- .../dremio/spark/FlightDataSourceReader.java | 61 ++++++++++++++++--- .../com/dremio/spark/FlightSparkContext.java | 22 +++++++ .../java/com/dremio/spark/TestConnector.java | 37 +++++++++++ 4 files changed, 112 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java index dfda33e..6a50db6 100644 --- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java +++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java @@ -21,7 +21,7 @@ public FlightDataReaderFactory( @Override public String[] preferredLocations() { - return new String[0]; //endpoint.getLocations().stream().map(Location::getHost).toArray(String[]::new); + return new String[]{defaultHost}; } @Override diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index 7f5e267..a6c747d 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -1,8 +1,8 @@ package com.dremio.spark; -import com.clearspring.analytics.util.Lists; import com.dremio.proto.flight.commands.Command; import com.google.common.base.Joiner; +import com.google.common.collect.Lists; import io.protostuff.LinkedBuffer; import io.protostuff.ProtostuffIOUtil; import org.apache.arrow.flight.FlightClient; @@ -19,6 +19,7 @@ import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; +import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns; import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; @@ -30,13 +31,18 @@ import org.slf4j.LoggerFactory; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; -public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters { +import scala.collection.JavaConversions; + +public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters, SupportsPushDownRequiredColumns { private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); - private static final Joiner JOINER = Joiner.on(" and "); + private static final Joiner WHERE_JOINER = Joiner.on(" and "); + private static final Joiner PROJ_JOINER = Joiner.on(", "); private FlightInfo info; private FlightDescriptor descriptor; + private StructType schema; private final LinkedBuffer buffer = LinkedBuffer.allocate(); private final Location defaultLocation; private final FlightClientFactory clientFactory; @@ -71,7 +77,7 @@ private FlightDescriptor getDescriptor(boolean isSql, String path) { return FlightDescriptor.command(message); } - public StructType readSchema() { + private StructType readSchemaImpl() { StructField[] fields = info.getSchema().getFields().stream() .map(field -> new StructField(field.getName(), @@ -82,6 +88,13 @@ public StructType readSchema() { return new StructType(fields); } + public StructType readSchema() { + if (schema == null) { + schema = readSchemaImpl(); + } + return schema; + } + private DataType sparkFromArrow(FieldType fieldType) { switch (fieldType.getType().getTypeID()) { case Null: @@ -178,10 
+191,12 @@ public Filter[] pushFilters(Filter[] filters) { } private void mergeWhereDescriptors(String whereClause) { - if (sql.contains(" where ")) { - throw new UnsupportedOperationException("have not yet done the regex to insert where clauses"); - } - sql += " where " + whereClause; + sql = String.format("select * from (%s) where %s", sql, whereClause); + descriptor = getDescriptor(true, sql); + } + + private void mergeProjDescriptors(String projClause) { + sql = String.format("select %s from (%s)", projClause, sql); descriptor = getDescriptor(true, sql); } @@ -194,7 +209,7 @@ private String generateWhereClause(List pushed) { filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value()))); } } - return JOINER.join(filterStr); + return WHERE_JOINER.join(filterStr); } private String valueToString(Object value) { @@ -218,4 +233,32 @@ private boolean canBePushed(Filter filter) { public Filter[] pushedFilters() { return pushed; } + + @Override + public void pruneColumns(StructType requiredSchema) { + if (requiredSchema.toSeq().isEmpty()) { + return; + } + StructType schema = readSchema(); + List fields = Lists.newArrayList(); + List fieldsLeft = Lists.newArrayList(); + Map fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f->f)); + for (StructField field: JavaConversions.seqAsJavaList(requiredSchema.toSeq())) { + String name = field.name(); + StructField f = fieldNames.remove(name); + if (f != null) { + fields.add(String.format("\"%s\"",name)); + fieldsLeft.add(f); + } + } + if (!fieldNames.isEmpty()) { + this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); + mergeProjDescriptors(PROJ_JOINER.join(fields)); + try (FlightClient client = clientFactory.apply()) { + info = client.getInfo(descriptor); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } } diff --git a/src/main/java/com/dremio/spark/FlightSparkContext.java b/src/main/java/com/dremio/spark/FlightSparkContext.java index 8d5caf6..b7c5ae2 100644 --- a/src/main/java/com/dremio/spark/FlightSparkContext.java +++ b/src/main/java/com/dremio/spark/FlightSparkContext.java @@ -29,6 +29,7 @@ public Dataset read(String s) { .option("username", conf.get("spark.flight.auth.username")) .option("password", conf.get("spark.flight.auth.password")) .option("isSql", false) + .option("parallel", false) .load(s); } @@ -38,6 +39,27 @@ public Dataset readSql(String s) { .option("username", conf.get("spark.flight.auth.username")) .option("password", conf.get("spark.flight.auth.password")) .option("isSql", true) + .option("parallel", false) + .load(s); + } + + public Dataset read(String s, boolean parallel) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("isSql", false) + .option("parallel", parallel) + .load(s); + } + + public Dataset readSql(String s, boolean parallel) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("isSql", true) + .option("parallel", parallel) .load(s); } } diff --git 
a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index 13b170c..b3e8f4f 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -13,6 +13,7 @@ import org.apache.spark.api.java.JavaSparkContext; import java.util.Properties; +import java.util.function.Consumer; public class TestConnector { private static SparkConf conf; @@ -71,6 +72,42 @@ public void testFilter() { Assert.assertTrue(count < countOriginal); } + private static class SizeConsumer implements Consumer { + private int length = 0; + private int width = 0; + + @Override + public void accept(Row row) { + length+=1; + width = row.length(); + } + } + + @Test + public void testProject() { + Dataset df = csc.readSql("select * from \"sys\".options"); + SizeConsumer c = new SizeConsumer(); + df.select("name", "kind", "type").toLocalIterator().forEachRemaining(c); + long count = c.width; + long countOriginal = csc.readSql("select * from \"sys\".options").columns().length; + Assert.assertTrue(count < countOriginal); + } + + @Test + public void testParallel() { + Dataset df = csc.readSql("select * from \"sys\".options", true); + SizeConsumer c = new SizeConsumer(); + SizeConsumer c2 = new SizeConsumer(); + df.select("name", "kind", "type").filter(df.col("kind").equalTo("LONG")).toLocalIterator().forEachRemaining(c); + long width = c.width; + long length = c.length; + csc.readSql("select * from \"sys\".options", true).toLocalIterator().forEachRemaining(c2); + long widthOriginal = c2.width; + long lengthOriginal = c2.length; + Assert.assertTrue(width < widthOriginal); + Assert.assertTrue(length < lengthOriginal); + } + @Ignore @Test public void testSpeed() { From de8e262b2a77a9c65b1b4a88094ac3d76b3e5578 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Fri, 21 Jun 2019 20:11:38 +0100 Subject: [PATCH 06/38] finished and working with dremio --- pom.xml | 2 +- .../dremio/proto/flight/commands/Command.java | 69 +++++++++++++++++-- .../com/dremio/spark/FlightClientFactory.java | 2 +- .../com/dremio/spark/FlightDataReader.java | 2 +- .../dremio/spark/FlightDataSourceReader.java | 29 ++++++-- .../java/com/dremio/spark/TestConnector.java | 9 ++- 6 files changed, 98 insertions(+), 15 deletions(-) diff --git a/pom.xml b/pom.xml index 3e6241f..c29fe16 100644 --- a/pom.xml +++ b/pom.xml @@ -9,7 +9,7 @@ 1.0-SNAPSHOT - 3.1.10-201904162146020182-adf690d + 3.2.4-201906051751050278-1bcce62 0.14.0-SNAPSHOT 2.4.2 1.7.25 diff --git a/src/main/java/com/dremio/proto/flight/commands/Command.java b/src/main/java/com/dremio/proto/flight/commands/Command.java index cf3af29..3bc1376 100644 --- a/src/main/java/com/dremio/proto/flight/commands/Command.java +++ b/src/main/java/com/dremio/proto/flight/commands/Command.java @@ -10,6 +10,7 @@ import java.io.ObjectOutput; import java.util.Objects; +import io.protostuff.ByteString; import io.protostuff.GraphIOUtil; import io.protostuff.Input; import io.protostuff.Message; @@ -38,6 +39,10 @@ public static Command getDefaultInstance() query; private Boolean parallel; + private Boolean + coalesce; + private ByteString + ticket; public Command() { @@ -46,11 +51,15 @@ public Command() public Command( String query, - Boolean parallel + Boolean parallel, + Boolean coalesce, + ByteString ticket ) { this.query = query; this.parallel = parallel; + this.coalesce = coalesce; + this.ticket = ticket; } // getters and setters @@ -83,6 +92,34 @@ public Command setParallel(Boolean return this; } + // coalesce + public Boolean + 
getCoalesce() + { + return coalesce; + } + + public Command setCoalesce(Boolean + coalesce) + { + this.coalesce = coalesce; + return this; + } + + // ticket + public ByteString + getTicket() + { + return ticket; + } + + public Command setTicket(ByteString + ticket) + { + this.ticket = ticket; + return this; + } + @Override public boolean equals(Object obj) { if (this == obj) { @@ -94,12 +131,14 @@ public boolean equals(Object obj) { final Command that = (Command) obj; return Objects.equals(this.query, that.query) && - Objects.equals(this.parallel, that.parallel); + Objects.equals(this.parallel, that.parallel) && + Objects.equals(this.coalesce, that.coalesce) && + Objects.equals(this.ticket, that.ticket); } @Override public int hashCode() { - return Objects.hash(query, parallel); + return Objects.hash(query, parallel, coalesce, ticket); } @Override @@ -107,6 +146,8 @@ public String toString() { return "Command{" + "query=" + query + ", parallel=" + parallel + + ", coalesce=" + coalesce + + ", ticket=" + ticket + '}'; } // java serialization @@ -154,7 +195,9 @@ public boolean isInitialized(Command message) { return message.query != null - && message.parallel != null; + && message.parallel != null + && message.coalesce != null + && message.ticket != null; } public void mergeFrom(Input input, Command message) throws IOException @@ -171,6 +214,12 @@ public void mergeFrom(Input input, Command message) throws IOException case 2: message.parallel = input.readBool(); break; + case 3: + message.coalesce = input.readBool(); + break; + case 4: + message.ticket = input.readBytes(); + break; default: input.handleUnknownField(number, this); } @@ -187,6 +236,14 @@ public void writeTo(Output output, Command message) throws IOException if(message.parallel == null) throw new UninitializedMessageException(message); output.writeBool(2, message.parallel, false); + + if(message.coalesce == null) + throw new UninitializedMessageException(message); + output.writeBool(3, message.coalesce, false); + + if(message.ticket == null) + throw new UninitializedMessageException(message); + output.writeBytes(4, message.ticket, false); } public String getFieldName(int number) @@ -195,6 +252,8 @@ public String getFieldName(int number) { case 1: return "query"; case 2: return "parallel"; + case 3: return "coalesce"; + case 4: return "ticket"; default: return null; } } @@ -210,6 +269,8 @@ public int getFieldNumber(String name) { __fieldMap.put("query", 1); __fieldMap.put("parallel", 2); + __fieldMap.put("coalesce", 3); + __fieldMap.put("ticket", 4); } diff --git a/src/main/java/com/dremio/spark/FlightClientFactory.java b/src/main/java/com/dremio/spark/FlightClientFactory.java index f6c778f..f95be42 100644 --- a/src/main/java/com/dremio/spark/FlightClientFactory.java +++ b/src/main/java/com/dremio/spark/FlightClientFactory.java @@ -18,7 +18,7 @@ public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, } public FlightClient apply() { - FlightClient client = new FlightClient(allocator, defaultLocation); + FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); client.authenticateBasic(username, password); return client; diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index 6b96a4d..e28852f 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -24,7 +24,7 @@ public FlightDataReader( String defaultHost, int defaultPort) { this.allocator = new 
RootAllocator(); - client = new FlightClient(this.allocator, new Location(defaultHost, defaultPort)); //todo multiple locations + client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations client.authenticateBasic("dremio", "dremio123"); stream = client.getStream(new Ticket(ticket)); } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index a6c747d..2e5ebda 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -3,6 +3,7 @@ import com.dremio.proto.flight.commands.Command; import com.google.common.base.Joiner; import com.google.common.collect.Lists; +import io.protostuff.ByteString; import io.protostuff.LinkedBuffer; import io.protostuff.ProtostuffIOUtil; import org.apache.arrow.flight.FlightClient; @@ -51,7 +52,7 @@ public class FlightDataSourceReader implements SupportsScanColumnarBatch, Suppor private Filter[] pushed; public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { - defaultLocation = new Location( + defaultLocation = Location.forGrpcInsecure( dataSourceOptions.get("host").orElse("localhost"), dataSourceOptions.getInt("port", 47470) ); @@ -72,7 +73,7 @@ public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocat private FlightDescriptor getDescriptor(boolean isSql, String path) { String query = (!isSql) ? ("select * from " + path) : path; - byte[] message = ProtostuffIOUtil.toByteArray(new Command(query , parallel), Command.getSchema(), buffer); + byte[] message = ProtostuffIOUtil.toByteArray(new Command(query , parallel, false, ByteString.EMPTY), Command.getSchema(), buffer); buffer.clear(); return FlightDescriptor.command(message); } @@ -155,13 +156,31 @@ private DataType sparkFromArrow(FieldType fieldType) { @Override public List> planBatchInputPartitions() { + if (parallel) { + return planBatchInputPartitionsParallel(); + } + return planBatchInputPartitionsSerial(info); + } + + private List> planBatchInputPartitionsParallel() { + byte[] message = ProtostuffIOUtil.toByteArray(new Command("", true, true, ByteString.copyFrom(info.getEndpoints().get(0).getTicket().getBytes())), Command.getSchema(), buffer); + buffer.clear(); + try (FlightClient client = clientFactory.apply()) { + FlightInfo info = client.getInfo(FlightDescriptor.command(message)); + return planBatchInputPartitionsSerial(info); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private List> planBatchInputPartitionsSerial(FlightInfo info) { return info.getEndpoints().stream().map(endpoint -> { Location location = (endpoint.getLocations().isEmpty()) ? 
- new Location(defaultLocation.getHost(), defaultLocation.getPort()) : + Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) : endpoint.getLocations().get(0); return new FlightDataReaderFactory(endpoint.getTicket().getBytes(), - location.getHost(), - location.getPort()); + location.getUri().getHost(), + location.getUri().getPort()); }).collect(Collectors.toList()); } diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index b3e8f4f..9dcb940 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -95,13 +95,16 @@ public void testProject() { @Test public void testParallel() { - Dataset df = csc.readSql("select * from \"sys\".options", true); + String easySql = "select * from sys.options"; + String hardSql = "select * from \"@dremio\".test"; + Dataset df = csc.readSql(hardSql, true); SizeConsumer c = new SizeConsumer(); SizeConsumer c2 = new SizeConsumer(); - df.select("name", "kind", "type").filter(df.col("kind").equalTo("LONG")).toLocalIterator().forEachRemaining(c); + Dataset dff = df.select("bid", "ask", "symbol").filter(df.col("symbol").equalTo("USDCAD")); + dff.toLocalIterator().forEachRemaining(c); long width = c.width; long length = c.length; - csc.readSql("select * from \"sys\".options", true).toLocalIterator().forEachRemaining(c2); + csc.readSql(hardSql, true).toLocalIterator().forEachRemaining(c2); long widthOriginal = c2.width; long lengthOriginal = c2.length; Assert.assertTrue(width < widthOriginal); From bf6942541fd70bda36a8c81054018f59fa1f25a7 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Tue, 15 Oct 2019 12:01:23 +0100 Subject: [PATCH 07/38] Create README.md --- README.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..07410c4 --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ +Spark source for Flight-enabled endpoints +========================================= + +This uses the new [Source V2 Interface](https://databricks.com/session/apache-spark-data-source-v2) to connect to +[Apache Arrow Flight](https://www.dremio.com/understanding-apache-arrow-flight/) endpoints. It is a prototype of what is +possible with Arrow Flight. The prototype has achieved a 50x speed-up compared to a serial JDBC driver and scales with the +number of Flight endpoints/Spark executors run in parallel.
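+
+A rough usage sketch follows (untested; it assumes `spark` is an active `SparkSession` with the connector
+jar on its classpath, and relies on Spark resolving the provider by appending `.DefaultSource` to the
+package name; the host, port, credentials and query below are placeholders for your own endpoint):
+
+```java
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+
+Dataset<Row> df = spark.read()
+    .format("com.dremio.spark")       // package containing DefaultSource
+    .option("host", "localhost")      // Flight endpoint host (default: localhost)
+    .option("port", 47470)            // Flight endpoint port (default: 47470)
+    .option("username", "anonymous")
+    .option("password", "")
+    .option("isSql", true)            // treat the load() argument as a SQL query
+    .option("parallel", false)        // whether to ask the endpoint for a parallel, multi-endpoint read
+    .load("select * from \"sys\".options");
+df.show();
+```
+
+`FlightSparkContext` wraps the same options behind its `read`/`readSql` helpers.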
+ +It currently supports: + +* columnar batch reading +* reading many Flight endpoints in parallel as Spark partitions +* filter and projection pushdown + +It currently lacks: + +* support for all Spark/Arrow data types and filters +* independence from [Dremio's flight endpoint](https://github.com/dremio-hub/dremio-flight-connector); it should be abstracted +to generic Flight sources +* an update to support the new features in Arrow 0.15.0 +* a write interface that uses `DoPut` to write Spark dataframes back to an Arrow Flight endpoint +* use of the transactional capabilities of the Spark Source V2 interface +* a proper benchmark test From 9b123640ab634f987e01f28dfc0f073b6c6e4fbe Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Tue, 15 Oct 2019 12:01:48 +0100 Subject: [PATCH 08/38] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 07410c4..996f6d9 100644 --- a/README.md +++ b/README.md @@ -21,3 +21,4 @@ to generic Flight sources * a write interface that uses `DoPut` to write Spark dataframes back to an Arrow Flight endpoint * use of the transactional capabilities of the Spark Source V2 interface * a proper benchmark test +* CI build & tests From 408f0e23ba526e6f8ce12cfb5a11f7340ef48bf9 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Tue, 15 Oct 2019 12:04:06 +0100 Subject: [PATCH 09/38] license etc --- .editorconfig | 27 ++ .gitignore | 1 + .mvn/extensions.xml | 30 +++ .mvn/wrapper/MavenWrapperDownloader.java | 117 +++++++++ .mvn/wrapper/maven-wrapper.properties | 2 + .travis.yml | 8 + LICENSE | 202 +++++++++++++++ NOTICE | 6 + mvnw | 310 +++++++++++++++++++++++ mvnw.cmd | 182 +++++++++++++ 10 files changed, 885 insertions(+) create mode 100644 .editorconfig create mode 100644 .mvn/extensions.xml create mode 100644 .mvn/wrapper/MavenWrapperDownloader.java create mode 100644 .mvn/wrapper/maven-wrapper.properties create mode 100644 .travis.yml create mode 100755 LICENSE create mode 100755 NOTICE create mode 100755 mvnw create mode 100644 mvnw.cmd diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..daa1c54 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,27 @@ +# +# Copyright (C) 2017-2019 Dremio Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + + +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_size = 2 +indent_style = space + +[*.js] +trim_trailing_whitespace = true diff --git a/.gitignore b/.gitignore index 25f17f2..2f15439 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .checkstyle .classpath .idea/ +.vscode/ .project .mvn/wrapper/maven-wrapper.jar .profiler diff --git a/.mvn/extensions.xml b/.mvn/extensions.xml new file mode 100644 index 0000000..b9a3245 --- /dev/null +++ b/.mvn/extensions.xml @@ -0,0 +1,30 @@ + + + + + fr.jcgay.maven + maven-profiler + 2.6 + + + fr.jcgay.maven + maven-notifier + 1.10.1 + + diff --git a/.mvn/wrapper/MavenWrapperDownloader.java b/.mvn/wrapper/MavenWrapperDownloader.java new file mode 100644 index 0000000..1ef8d69 --- /dev/null +++ b/.mvn/wrapper/MavenWrapperDownloader.java @@ -0,0 +1,117 @@ +/* + * Copyright 2007-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import java.net.*; +import java.io.*; +import java.nio.channels.*; +import java.util.Properties; + +public class MavenWrapperDownloader { + + private static final String WRAPPER_VERSION = "0.5.4"; + /** + * Default URL to download the maven-wrapper.jar from, if no 'downloadUrl' is provided. + */ + private static final String DEFAULT_DOWNLOAD_URL = "https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/" + + WRAPPER_VERSION + "/maven-wrapper-" + WRAPPER_VERSION + ".jar"; + + /** + * Path to the maven-wrapper.properties file, which might contain a downloadUrl property to + * use instead of the default one. + */ + private static final String MAVEN_WRAPPER_PROPERTIES_PATH = + ".mvn/wrapper/maven-wrapper.properties"; + + /** + * Path where the maven-wrapper.jar will be saved to. + */ + private static final String MAVEN_WRAPPER_JAR_PATH = + ".mvn/wrapper/maven-wrapper.jar"; + + /** + * Name of the property which should be used to override the default download url for the wrapper. + */ + private static final String PROPERTY_NAME_WRAPPER_URL = "wrapperUrl"; + + public static void main(String args[]) { + System.out.println("- Downloader started"); + File baseDirectory = new File(args[0]); + System.out.println("- Using base directory: " + baseDirectory.getAbsolutePath()); + + // If the maven-wrapper.properties exists, read it and check if it contains a custom + // wrapperUrl parameter.
+ File mavenWrapperPropertyFile = new File(baseDirectory, MAVEN_WRAPPER_PROPERTIES_PATH); + String url = DEFAULT_DOWNLOAD_URL; + if(mavenWrapperPropertyFile.exists()) { + FileInputStream mavenWrapperPropertyFileInputStream = null; + try { + mavenWrapperPropertyFileInputStream = new FileInputStream(mavenWrapperPropertyFile); + Properties mavenWrapperProperties = new Properties(); + mavenWrapperProperties.load(mavenWrapperPropertyFileInputStream); + url = mavenWrapperProperties.getProperty(PROPERTY_NAME_WRAPPER_URL, url); + } catch (IOException e) { + System.out.println("- ERROR loading '" + MAVEN_WRAPPER_PROPERTIES_PATH + "'"); + } finally { + try { + if(mavenWrapperPropertyFileInputStream != null) { + mavenWrapperPropertyFileInputStream.close(); + } + } catch (IOException e) { + // Ignore ... + } + } + } + System.out.println("- Downloading from: " + url); + + File outputFile = new File(baseDirectory.getAbsolutePath(), MAVEN_WRAPPER_JAR_PATH); + if(!outputFile.getParentFile().exists()) { + if(!outputFile.getParentFile().mkdirs()) { + System.out.println( + "- ERROR creating output directory '" + outputFile.getParentFile().getAbsolutePath() + "'"); + } + } + System.out.println("- Downloading to: " + outputFile.getAbsolutePath()); + try { + downloadFileFromURL(url, outputFile); + System.out.println("Done"); + System.exit(0); + } catch (Throwable e) { + System.out.println("- Error downloading"); + e.printStackTrace(); + System.exit(1); + } + } + + private static void downloadFileFromURL(String urlString, File destination) throws Exception { + if (System.getenv("MVNW_USERNAME") != null && System.getenv("MVNW_PASSWORD") != null) { + String username = System.getenv("MVNW_USERNAME"); + char[] password = System.getenv("MVNW_PASSWORD").toCharArray(); + Authenticator.setDefault(new Authenticator() { + @Override + protected PasswordAuthentication getPasswordAuthentication() { + return new PasswordAuthentication(username, password); + } + }); + } + URL website = new URL(urlString); + ReadableByteChannel rbc; + rbc = Channels.newChannel(website.openStream()); + FileOutputStream fos = new FileOutputStream(destination); + fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE); + fos.close(); + rbc.close(); + } + +} diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties new file mode 100644 index 0000000..05c741e --- /dev/null +++ b/.mvn/wrapper/maven-wrapper.properties @@ -0,0 +1,2 @@ +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.0/apache-maven-3.6.0-bin.zip +wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..7fb7a88 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,8 @@ +dist: xenial +language: java +jdk: openjdk8 +cache: + directories: + - $HOME/.m2 +install: mvn install -DskipTests=true -Dmaven.javadoc.skip=true -B -V +script: mvn test -B diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..0cfd14b --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017 - Dremio Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/NOTICE b/NOTICE new file mode 100755 index 0000000..d2b0ae2 --- /dev/null +++ b/NOTICE @@ -0,0 +1,6 @@ +Dremio +Copyright 2015-2017 Dremio Corporation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + diff --git a/mvnw b/mvnw new file mode 100755 index 0000000..35ff643 --- /dev/null +++ b/mvnw @@ -0,0 +1,310 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ---------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------- +# Maven2 Start Up Batch script +# +# Required ENV vars: +# ------------------ +# JAVA_HOME - location of a JDK home dir +# +# Optional ENV vars +# ----------------- +# M2_HOME - location of maven2's installed home dir +# MAVEN_OPTS - parameters passed to the Java VM when running Maven +# e.g. to debug Maven itself, use +# set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 +# MAVEN_SKIP_RC - flag to disable loading of mavenrc files +# ---------------------------------------------------------------------------- + +if [ -z "$MAVEN_SKIP_RC" ] ; then + + if [ -f /etc/mavenrc ] ; then + . /etc/mavenrc + fi + + if [ -f "$HOME/.mavenrc" ] ; then + . "$HOME/.mavenrc" + fi + +fi + +# OS specific support. $var _must_ be set to either true or false. +cygwin=false; +darwin=false; +mingw=false +case "`uname`" in + CYGWIN*) cygwin=true ;; + MINGW*) mingw=true;; + Darwin*) darwin=true + # Use /usr/libexec/java_home if available, otherwise fall back to /Library/Java/Home + # See https://developer.apple.com/library/mac/qa/qa1170/_index.html + if [ -z "$JAVA_HOME" ]; then + if [ -x "/usr/libexec/java_home" ]; then + export JAVA_HOME="`/usr/libexec/java_home`" + else + export JAVA_HOME="/Library/Java/Home" + fi + fi + ;; +esac + +if [ -z "$JAVA_HOME" ] ; then + if [ -r /etc/gentoo-release ] ; then + JAVA_HOME=`java-config --jre-home` + fi +fi + +if [ -z "$M2_HOME" ] ; then + ## resolve links - $0 may be a link to maven's home + PRG="$0" + + # need this for relative symlinks + while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG="`dirname "$PRG"`/$link" + fi + done + + saveddir=`pwd` + + M2_HOME=`dirname "$PRG"`/.. 
+ + # make it fully qualified + M2_HOME=`cd "$M2_HOME" && pwd` + + cd "$saveddir" + # echo Using m2 at $M2_HOME +fi + +# For Cygwin, ensure paths are in UNIX format before anything is touched +if $cygwin ; then + [ -n "$M2_HOME" ] && + M2_HOME=`cygpath --unix "$M2_HOME"` + [ -n "$JAVA_HOME" ] && + JAVA_HOME=`cygpath --unix "$JAVA_HOME"` + [ -n "$CLASSPATH" ] && + CLASSPATH=`cygpath --path --unix "$CLASSPATH"` +fi + +# For Mingw, ensure paths are in UNIX format before anything is touched +if $mingw ; then + [ -n "$M2_HOME" ] && + M2_HOME="`(cd "$M2_HOME"; pwd)`" + [ -n "$JAVA_HOME" ] && + JAVA_HOME="`(cd "$JAVA_HOME"; pwd)`" +fi + +if [ -z "$JAVA_HOME" ]; then + javaExecutable="`which javac`" + if [ -n "$javaExecutable" ] && ! [ "`expr \"$javaExecutable\" : '\([^ ]*\)'`" = "no" ]; then + # readlink(1) is not available as standard on Solaris 10. + readLink=`which readlink` + if [ ! `expr "$readLink" : '\([^ ]*\)'` = "no" ]; then + if $darwin ; then + javaHome="`dirname \"$javaExecutable\"`" + javaExecutable="`cd \"$javaHome\" && pwd -P`/javac" + else + javaExecutable="`readlink -f \"$javaExecutable\"`" + fi + javaHome="`dirname \"$javaExecutable\"`" + javaHome=`expr "$javaHome" : '\(.*\)/bin'` + JAVA_HOME="$javaHome" + export JAVA_HOME + fi + fi +fi + +if [ -z "$JAVACMD" ] ; then + if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + else + JAVACMD="`which java`" + fi +fi + +if [ ! -x "$JAVACMD" ] ; then + echo "Error: JAVA_HOME is not defined correctly." >&2 + echo " We cannot execute $JAVACMD" >&2 + exit 1 +fi + +if [ -z "$JAVA_HOME" ] ; then + echo "Warning: JAVA_HOME environment variable is not set." +fi + +CLASSWORLDS_LAUNCHER=org.codehaus.plexus.classworlds.launcher.Launcher + +# traverses directory structure from process work directory to filesystem root +# first directory with .mvn subdirectory is considered project base directory +find_maven_basedir() { + + if [ -z "$1" ] + then + echo "Path not specified to find_maven_basedir" + return 1 + fi + + basedir="$1" + wdir="$1" + while [ "$wdir" != '/' ] ; do + if [ -d "$wdir"/.mvn ] ; then + basedir=$wdir + break + fi + # workaround for JBEAP-8937 (on Solaris 10/Sparc) + if [ -d "${wdir}" ]; then + wdir=`cd "$wdir/.."; pwd` + fi + # end of workaround + done + echo "${basedir}" +} + +# concatenates all lines of a file +concat_lines() { + if [ -f "$1" ]; then + echo "$(tr -s '\n' ' ' < "$1")" + fi +} + +BASE_DIR=`find_maven_basedir "$(pwd)"` +if [ -z "$BASE_DIR" ]; then + exit 1; +fi + +########################################################################################## +# Extension to allow automatically downloading the maven-wrapper.jar from Maven-central +# This allows using the maven wrapper in projects that prohibit checking in binary data. +########################################################################################## +if [ -r "$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" ]; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found .mvn/wrapper/maven-wrapper.jar" + fi +else + if [ "$MVNW_VERBOSE" = true ]; then + echo "Couldn't find .mvn/wrapper/maven-wrapper.jar, downloading it ..." 
+ fi + if [ -n "$MVNW_REPOURL" ]; then + jarUrl="$MVNW_REPOURL/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + else + jarUrl="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + fi + while IFS="=" read key value; do + case "$key" in (wrapperUrl) jarUrl="$value"; break ;; + esac + done < "$BASE_DIR/.mvn/wrapper/maven-wrapper.properties" + if [ "$MVNW_VERBOSE" = true ]; then + echo "Downloading from: $jarUrl" + fi + wrapperJarPath="$BASE_DIR/.mvn/wrapper/maven-wrapper.jar" + if $cygwin; then + wrapperJarPath=`cygpath --path --windows "$wrapperJarPath"` + fi + + if command -v wget > /dev/null; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found wget ... using wget" + fi + if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then + wget "$jarUrl" -O "$wrapperJarPath" + else + wget --http-user=$MVNW_USERNAME --http-password=$MVNW_PASSWORD "$jarUrl" -O "$wrapperJarPath" + fi + elif command -v curl > /dev/null; then + if [ "$MVNW_VERBOSE" = true ]; then + echo "Found curl ... using curl" + fi + if [ -z "$MVNW_USERNAME" ] || [ -z "$MVNW_PASSWORD" ]; then + curl -o "$wrapperJarPath" "$jarUrl" -f + else + curl --user $MVNW_USERNAME:$MVNW_PASSWORD -o "$wrapperJarPath" "$jarUrl" -f + fi + + else + if [ "$MVNW_VERBOSE" = true ]; then + echo "Falling back to using Java to download" + fi + javaClass="$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.java" + # For Cygwin, switch paths to Windows format before running javac + if $cygwin; then + javaClass=`cygpath --path --windows "$javaClass"` + fi + if [ -e "$javaClass" ]; then + if [ ! -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then + if [ "$MVNW_VERBOSE" = true ]; then + echo " - Compiling MavenWrapperDownloader.java ..." + fi + # Compiling the Java class + ("$JAVA_HOME/bin/javac" "$javaClass") + fi + if [ -e "$BASE_DIR/.mvn/wrapper/MavenWrapperDownloader.class" ]; then + # Running the downloader + if [ "$MVNW_VERBOSE" = true ]; then + echo " - Running MavenWrapperDownloader.java ..." + fi + ("$JAVA_HOME/bin/java" -cp .mvn/wrapper MavenWrapperDownloader "$MAVEN_PROJECTBASEDIR") + fi + fi + fi +fi +########################################################################################## +# End of extension +########################################################################################## + +export MAVEN_PROJECTBASEDIR=${MAVEN_BASEDIR:-"$BASE_DIR"} +if [ "$MVNW_VERBOSE" = true ]; then + echo $MAVEN_PROJECTBASEDIR +fi +MAVEN_OPTS="$(concat_lines "$MAVEN_PROJECTBASEDIR/.mvn/jvm.config") $MAVEN_OPTS" + +# For Cygwin, switch paths to Windows format before running java +if $cygwin; then + [ -n "$M2_HOME" ] && + M2_HOME=`cygpath --path --windows "$M2_HOME"` + [ -n "$JAVA_HOME" ] && + JAVA_HOME=`cygpath --path --windows "$JAVA_HOME"` + [ -n "$CLASSPATH" ] && + CLASSPATH=`cygpath --path --windows "$CLASSPATH"` + [ -n "$MAVEN_PROJECTBASEDIR" ] && + MAVEN_PROJECTBASEDIR=`cygpath --path --windows "$MAVEN_PROJECTBASEDIR"` +fi + +# Provide a "standardized" way to retrieve the CLI args that will +# work with both Windows and non-Windows executions. 
+MAVEN_CMD_LINE_ARGS="$MAVEN_CONFIG $@" +export MAVEN_CMD_LINE_ARGS + +WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain + +exec "$JAVACMD" \ + $MAVEN_OPTS \ + -classpath "$MAVEN_PROJECTBASEDIR/.mvn/wrapper/maven-wrapper.jar" \ + "-Dmaven.home=${M2_HOME}" "-Dmaven.multiModuleProjectDirectory=${MAVEN_PROJECTBASEDIR}" \ + ${WRAPPER_LAUNCHER} $MAVEN_CONFIG "$@" diff --git a/mvnw.cmd b/mvnw.cmd new file mode 100644 index 0000000..dae46d4 --- /dev/null +++ b/mvnw.cmd @@ -0,0 +1,182 @@ +@REM ---------------------------------------------------------------------------- +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at +@REM +@REM http://www.apache.org/licenses/LICENSE-2.0 +@REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM ---------------------------------------------------------------------------- + +@REM ---------------------------------------------------------------------------- +@REM Maven2 Start Up Batch script +@REM +@REM Required ENV vars: +@REM JAVA_HOME - location of a JDK home dir +@REM +@REM Optional ENV vars +@REM M2_HOME - location of maven2's installed home dir +@REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands +@REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a key stroke before ending +@REM MAVEN_OPTS - parameters passed to the Java VM when running Maven +@REM e.g. to debug Maven itself, use +@REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 +@REM MAVEN_SKIP_RC - flag to disable loading of mavenrc files +@REM ---------------------------------------------------------------------------- + +@REM Begin all REM lines with '@' in case MAVEN_BATCH_ECHO is 'on' +@echo off +@REM set title of command window +title %0 +@REM enable echoing by setting MAVEN_BATCH_ECHO to 'on' +@if "%MAVEN_BATCH_ECHO%" == "on" echo %MAVEN_BATCH_ECHO% + +@REM set %HOME% to equivalent of $HOME +if "%HOME%" == "" (set "HOME=%HOMEDRIVE%%HOMEPATH%") + +@REM Execute a user defined script before this one +if not "%MAVEN_SKIP_RC%" == "" goto skipRcPre +@REM check for pre script, once with legacy .bat ending and once with .cmd ending +if exist "%HOME%\mavenrc_pre.bat" call "%HOME%\mavenrc_pre.bat" +if exist "%HOME%\mavenrc_pre.cmd" call "%HOME%\mavenrc_pre.cmd" +:skipRcPre + +@setlocal + +set ERROR_CODE=0 + +@REM To isolate internal variables from possible post scripts, we use another setlocal +@setlocal + +@REM ==== START VALIDATION ==== +if not "%JAVA_HOME%" == "" goto OkJHome + +echo. +echo Error: JAVA_HOME not found in your environment. >&2 +echo Please set the JAVA_HOME variable in your environment to match the >&2 +echo location of your Java installation. >&2 +echo. +goto error + +:OkJHome +if exist "%JAVA_HOME%\bin\java.exe" goto init + +echo. +echo Error: JAVA_HOME is set to an invalid directory. 
>&2 +echo JAVA_HOME = "%JAVA_HOME%" >&2 +echo Please set the JAVA_HOME variable in your environment to match the >&2 +echo location of your Java installation. >&2 +echo. +goto error + +@REM ==== END VALIDATION ==== + +:init + +@REM Find the project base dir, i.e. the directory that contains the folder ".mvn". +@REM Fallback to current working directory if not found. + +set MAVEN_PROJECTBASEDIR=%MAVEN_BASEDIR% +IF NOT "%MAVEN_PROJECTBASEDIR%"=="" goto endDetectBaseDir + +set EXEC_DIR=%CD% +set WDIR=%EXEC_DIR% +:findBaseDir +IF EXIST "%WDIR%"\.mvn goto baseDirFound +cd .. +IF "%WDIR%"=="%CD%" goto baseDirNotFound +set WDIR=%CD% +goto findBaseDir + +:baseDirFound +set MAVEN_PROJECTBASEDIR=%WDIR% +cd "%EXEC_DIR%" +goto endDetectBaseDir + +:baseDirNotFound +set MAVEN_PROJECTBASEDIR=%EXEC_DIR% +cd "%EXEC_DIR%" + +:endDetectBaseDir + +IF NOT EXIST "%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config" goto endReadAdditionalConfig + +@setlocal EnableExtensions EnableDelayedExpansion +for /F "usebackq delims=" %%a in ("%MAVEN_PROJECTBASEDIR%\.mvn\jvm.config") do set JVM_CONFIG_MAVEN_PROPS=!JVM_CONFIG_MAVEN_PROPS! %%a +@endlocal & set JVM_CONFIG_MAVEN_PROPS=%JVM_CONFIG_MAVEN_PROPS% + +:endReadAdditionalConfig + +SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" +set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" +set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain + +set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + +FOR /F "tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( + IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B +) + +@REM Extension to allow automatically downloading the maven-wrapper.jar from Maven-central +@REM This allows using the maven wrapper in projects that prohibit checking in binary data. +if exist %WRAPPER_JAR% ( + if "%MVNW_VERBOSE%" == "true" ( + echo Found %WRAPPER_JAR% + ) +) else ( + if not "%MVNW_REPOURL%" == "" ( + SET DOWNLOAD_URL="%MVNW_REPOURL%/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + ) + if "%MVNW_VERBOSE%" == "true" ( + echo Couldn't find %WRAPPER_JAR%, downloading it ... + echo Downloading from: %DOWNLOAD_URL% + ) + + powershell -Command "&{"^ + "$webclient = new-object System.Net.WebClient;"^ + "if (-not ([string]::IsNullOrEmpty('%MVNW_USERNAME%') -and [string]::IsNullOrEmpty('%MVNW_PASSWORD%'))) {"^ + "$webclient.Credentials = new-object System.Net.NetworkCredential('%MVNW_USERNAME%', '%MVNW_PASSWORD%');"^ + "}"^ + "[Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; $webclient.DownloadFile('%DOWNLOAD_URL%', '%WRAPPER_JAR%')"^ + "}" + if "%MVNW_VERBOSE%" == "true" ( + echo Finished downloading %WRAPPER_JAR% + ) +) +@REM End of extension + +@REM Provide a "standardized" way to retrieve the CLI args that will +@REM work with both Windows and non-Windows executions. 
+set MAVEN_CMD_LINE_ARGS=%* + +%MAVEN_JAVA_EXE% %JVM_CONFIG_MAVEN_PROPS% %MAVEN_OPTS% %MAVEN_DEBUG_OPTS% -classpath %WRAPPER_JAR% "-Dmaven.multiModuleProjectDirectory=%MAVEN_PROJECTBASEDIR%" %WRAPPER_LAUNCHER% %MAVEN_CONFIG% %* +if ERRORLEVEL 1 goto error +goto end + +:error +set ERROR_CODE=1 + +:end +@endlocal & set ERROR_CODE=%ERROR_CODE% + +if not "%MAVEN_SKIP_RC%" == "" goto skipRcPost +@REM check for post script, once with legacy .bat ending and once with .cmd ending +if exist "%HOME%\mavenrc_post.bat" call "%HOME%\mavenrc_post.bat" +if exist "%HOME%\mavenrc_post.cmd" call "%HOME%\mavenrc_post.cmd" +:skipRcPost + +@REM pause the script if MAVEN_BATCH_PAUSE is set to 'on' +if "%MAVEN_BATCH_PAUSE%" == "on" pause + +if "%MAVEN_TERMINATE_CMD%" == "on" exit %ERROR_CODE% + +exit /B %ERROR_CODE% From f60e590111cd50c3ec30866209bf74c7b08ceb6e Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Tue, 15 Oct 2019 12:05:36 +0100 Subject: [PATCH 10/38] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 996f6d9..1a8b507 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ Spark source for Flight-enabled endpoints ========================================= +[![Build Status](https://travis-ci.org/rymurr/flight-spark-source.svg?branch=master)](https://travis-ci.org/rymurr/flight-spark-source) + This uses the new [Source V2 Interface](https://databricks.com/session/apache-spark-data-source-v2) to connect to [Apache Arrow Flight](https://www.dremio.com/understanding-apache-arrow-flight/) endpoints. It is a prototype of what is possible with Arrow Flight. The prototype has achieved a 50x speed-up compared to a serial JDBC driver and scales with the From 918405c3f38dd117f8e6bc027d1a8874348fb6a1 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Mon, 28 Oct 2019 15:42:04 +0000 Subject: [PATCH 11/38] remove dremio specific stuff and bring up to date to 0.15.0 --- .editorconfig | 4 +- .mvn/extensions.xml | 4 +- .mvn/wrapper/MavenWrapperDownloader.java | 4 +- .mvn/wrapper/maven-wrapper.properties | 16 + mvnw | 29 +- mvnw.cmd | 29 +- pom.xml | 774 +++++++++--------- .../dremio/proto/flight/commands/Command.java | 277 ------- .../java/com/dremio/spark/DefaultSource.java | 24 +- .../com/dremio/spark/FlightClientFactory.java | 45 +- .../com/dremio/spark/FlightDataReader.java | 86 +- .../dremio/spark/FlightDataReaderFactory.java | 53 +- .../dremio/spark/FlightDataSourceReader.java | 462 ++++++----- .../com/dremio/spark/FlightSparkContext.java | 103 +-- .../java/com/dremio/spark/TestConnector.java | 71 +- src/test/resources/logback-test.xml | 4 +- 16 files changed, 905 insertions(+), 1080 deletions(-) delete mode 100644 src/main/java/com/dremio/proto/flight/commands/Command.java diff --git a/.editorconfig b/.editorconfig index daa1c54..7cfe605 100644 --- a/.editorconfig +++ b/.editorconfig @@ -1,11 +1,11 @@ # -# Copyright (C) 2017-2019 Dremio Corporation +# Copyright (C) 2019 Ryan Murray # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, diff --git a/.mvn/extensions.xml b/.mvn/extensions.xml index b9a3245..431b4b5 100644 --- a/.mvn/extensions.xml +++ b/.mvn/extensions.xml @@ -1,13 +1,13 @@ - 4.0.0 - - com.dremio - flight-spark-source - 1.0-SNAPSHOT + 4.0.0 - - 3.2.4-201906051751050278-1bcce62 - 0.14.0-SNAPSHOT - 2.4.2 - 1.7.25 - 1.4.4 - - - - - src/main/resources - true - - + com.dremio + flight-spark-source + 1.0-SNAPSHOT - - - org.apache.maven.plugins - maven-checkstyle-plugin - - src/main/checkstyle/checkstyle-config.xml - src/main/checkstyle/checkstyle-suppressions.xml - - - - com.mycila - license-maven-plugin - 3.0 - - - Copyright (C) ${project.inceptionYear} ${owner} + + 0.15.0 + 2.4.2 + 1.7.25 + 1.4.4 + + + + + kr.motd.maven + os-maven-plugin + 1.5.0.Final + + - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at + + + src/main/resources + true + + - http://www.apache.org/licenses/LICENSE-2.0 + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.1.0 + + src/main/checkstyle/checkstyle-config.xml + src/main/checkstyle/checkstyle-suppressions.xml + + + + com.mycila + license-maven-plugin + 3.0 + + +Copyright (C) ${project.inceptionYear} ${owner} - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - - Ryan Murray - 2019 - - - 2019 - - true - false +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - - src/** - * - **/.mvn/** - - - - **/*~ - **/#*# - **/.#* - **/%*% - **/._* - **/.repository/** - **/CVS - **/CVS/** - **/.cvsignore - **/RCS - **/RCS/** - **/SCCS - **/SCCS/** - **/vssver.scc - **/.svn - **/.svn/** - **/.arch-ids - **/.arch-ids/** - **/.bzr - **/.bzr/** - **/.MySCMServerInfo - **/.DS_Store - **/.metadata - **/.metadata/** - **/.hg - **/.hg/** - **/.hgignore - **/.git - **/.git/** - **/.gitignore - **/.gitmodules - **/BitKeeper - **/BitKeeper/** - **/ChangeSet - **/ChangeSet/** - **/_darcs - **/_darcs/** - **/.darcsrepo - **/.darcsrepo/** - **/-darcs-backup* - **/.darcs-temp-mail - - **/test-output/** - **/release.properties - **/dependency-reduced-pom.xml - **/release-pom.xml - **/pom.xml.releaseBackup - **/cobertura.ser - **/.clover/** - **/.classpath - **/.project - **/.settings/** - **/*.iml - **/*.ipr - **/*.iws - .idea/** - **/nb-configuration.xml - **/MANIFEST.MF - **/*.jpg - **/*.png - **/*.gif - **/*.ico - **/*.bmp - **/*.tiff - **/*.tif - **/*.cr2 - **/*.xcf - **/*.class - **/*.exe - **/*.dll - **/*.so - **/*.md5 - **/*.sha1 - **/*.jar - **/*.zip - **/*.rar - **/*.tar - **/*.tar.gz - **/*.tar.bz2 - **/*.gz - **/*.xls - **/META-INF/services/** - **/*.md - **/*.xls - **/*.doc - **/*.odt - **/*.ods - **/*.pdf - **/.travis.yml - **/*.swf - **/*.json - - **/*.eot - **/*.ttf - **/*.woff - **/*.xlsx - **/*.docx - **/*.ppt - **/*.pptx - **/*.patch - +http://www.apache.org/licenses/LICENSE-2.0 - - **/*.log - **/*.txt - **/*.csv - **/*.tsv - **/*.parquet - **/*.jks - **/*.nonformat - **/*.gzip - **/*.k - **/*.q - **/*.dat +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ + + Ryan Murray + 2019 + + + 2019 + + true + false - - **/Jenkinsfile - **/LICENSE - **/NOTICE - **/postinstall - **/.babelrc - **/.checkstyle - **/.eslintcache - **/.eslintignore - **/.eslintrc - **/git.properties - **/pom.xml.versionsBackup - **/q - **/c.java + + src/** + * + **/.mvn/** + + + + **/*~ + **/#*# + **/.#* + **/%*% + **/._* + **/.repository/** + **/CVS + **/CVS/** + **/.cvsignore + **/RCS + **/RCS/** + **/SCCS + **/SCCS/** + **/vssver.scc + **/.svn + **/.svn/** + **/.arch-ids + **/.arch-ids/** + **/.bzr + **/.bzr/** + **/.MySCMServerInfo + **/.DS_Store + **/.metadata + **/.metadata/** + **/.hg + **/.hg/** + **/.hgignore + **/.git + **/.git/** + **/.gitignore + **/.gitmodules + **/BitKeeper + **/BitKeeper/** + **/ChangeSet + **/ChangeSet/** + **/_darcs + **/_darcs/** + **/.darcsrepo + **/.darcsrepo/** + **/-darcs-backup* + **/.darcs-temp-mail + + **/test-output/** + **/release.properties + **/dependency-reduced-pom.xml + **/release-pom.xml + **/pom.xml.releaseBackup + **/cobertura.ser + **/.clover/** + **/.classpath + **/.project + **/.settings/** + **/*.iml + **/*.ipr + **/*.iws + .idea/** + **/nb-configuration.xml + **/MANIFEST.MF + **/*.jpg + **/*.png + **/*.gif + **/*.ico + **/*.bmp + **/*.tiff + **/*.tif + **/*.cr2 + **/*.xcf + **/*.class + **/*.exe + **/*.dll + **/*.so + **/*.md5 + **/*.sha1 + **/*.jar + **/*.zip + **/*.rar + **/*.tar + **/*.tar.gz + **/*.tar.bz2 + **/*.gz + **/*.xls + **/META-INF/services/** + **/*.md + **/*.xls + **/*.doc + **/*.odt + **/*.ods + **/*.pdf + **/.travis.yml + **/*.swf + **/*.json + + **/*.eot + **/*.ttf + **/*.woff + **/*.xlsx + **/*.docx + **/*.ppt + **/*.pptx + **/*.patch + - - **/node_modules/** - **/.idea/** - **/db/** - - - SLASHSTAR_STYLE - DOUBLEDASHES_STYLE - DOUBLESLASH_STYLE - DOUBLESLASH_STYLE - DOUBLESLASH_STYLE - SLASHSTAR_STYLE - SLASHSTAR_STYLE - SLASHSTAR_STYLE - SLASHSTAR_STYLE - SCRIPT_STYLE - SCRIPT_STYLE - SCRIPT_STYLE - DOUBLEDASHES_STYLE - SCRIPT_STYLE - SLASHSTAR_STYLE - SCRIPT_STYLE - SCRIPT_STYLE - SCRIPT_STYLE - XML_STYLE - SCRIPT_STYLE - - - - - default-cli - - format - - - - verify-license-headers - verify - - check - - - - - - maven-enforcer-plugin - - - avoid_bad_dependencies - verify - - enforce - - - - - - commons-logging - javax.servlet:servlet-api - org.mortbay.jetty:servlet-api - org.mortbay.jetty:servlet-api-2.5 - log4j:log4j - - - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - 8 - 8 - - - - - - - - - - - - org.apache.spark - spark-core_2.11 - ${spark.version} - - - org.slf4j - slf4j-log4j12 - - - commons-logging - commons-logging - - - log4j - log4j - - - - - org.apache.spark - spark-sql_2.11 - ${spark.version} - - - org.slf4j - slf4j-log4j12 - - - commons-logging - commons-logging - - - log4j - log4j - - - javax.servlet - servlet-api - - - - - org.apache.arrow - arrow-flight - ${arrow.version} - shaded - - - io.protostuff - protostuff-core - ${protostuff.version} - + + **/*.log + **/*.txt + **/*.csv + **/*.tsv + **/*.parquet + **/*.jks + **/*.nonformat + **/*.gzip + **/*.k + **/*.q + **/*.dat - - io.protostuff - protostuff-collectionschema - ${protostuff.version} - + + **/Jenkinsfile + **/LICENSE + **/NOTICE + **/postinstall + **/.babelrc + **/.checkstyle + **/.eslintcache + **/.eslintignore + **/.eslintrc + **/git.properties + **/pom.xml.versionsBackup + **/q + **/c.java - - io.protostuff - protostuff-runtime - ${protostuff.version} - + + **/node_modules/** + **/.idea/** + **/db/** + + + SLASHSTAR_STYLE + DOUBLEDASHES_STYLE + DOUBLESLASH_STYLE + DOUBLESLASH_STYLE + 
DOUBLESLASH_STYLE + SLASHSTAR_STYLE + SLASHSTAR_STYLE + SLASHSTAR_STYLE + SLASHSTAR_STYLE + SCRIPT_STYLE + SCRIPT_STYLE + SCRIPT_STYLE + DOUBLEDASHES_STYLE + SCRIPT_STYLE + SLASHSTAR_STYLE + SCRIPT_STYLE + SCRIPT_STYLE + SCRIPT_STYLE + XML_STYLE + SCRIPT_STYLE + + + + + default-cli + + format + + + + verify-license-headers + verify + + check + + + + + + maven-enforcer-plugin + 1.4.1 + + + avoid_bad_dependencies + verify + + enforce + + + + + + commons-logging + javax.servlet:servlet-api + org.mortbay.jetty:servlet-api + org.mortbay.jetty:servlet-api-2.5 + log4j:log4j + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.1 + + 8 + 8 + + + + + + + + + + + + org.apache.spark + spark-core_2.11 + ${spark.version} + + + org.slf4j + slf4j-log4j12 + + + commons-logging + commons-logging + + + log4j + log4j + + + + + org.apache.spark + spark-sql_2.11 + ${spark.version} + + + org.slf4j + slf4j-log4j12 + + + commons-logging + commons-logging + + + log4j + log4j + + + javax.servlet + servlet-api + + + + + org.apache.arrow + arrow-flight + ${arrow.version} + shaded + - - io.protostuff - protostuff-api - ${protostuff.version} - - - - - - - - - - - - - - org.slf4j - jul-to-slf4j - ${dep.slf4j.version} - test - + + org.slf4j + jul-to-slf4j + ${dep.slf4j.version} + test + - - org.slf4j - jcl-over-slf4j - ${dep.slf4j.version} - test - + + org.slf4j + jcl-over-slf4j + ${dep.slf4j.version} + test + - - org.slf4j - log4j-over-slf4j - ${dep.slf4j.version} - test - - - ch.qos.logback - logback-classic - 1.2.3 - test - - - de.huxhorn.lilith - de.huxhorn.lilith.logback.appender.multiplex-classic - 8.2.0 - test - - - junit - junit - 4.11 - test - - + + org.slf4j + log4j-over-slf4j + ${dep.slf4j.version} + test + + + ch.qos.logback + logback-classic + 1.2.3 + test + + + de.huxhorn.lilith + de.huxhorn.lilith.logback.appender.multiplex-classic + 8.2.0 + test + + + junit + junit + 4.11 + test + + - \ No newline at end of file + diff --git a/src/main/java/com/dremio/proto/flight/commands/Command.java b/src/main/java/com/dremio/proto/flight/commands/Command.java deleted file mode 100644 index 3bc1376..0000000 --- a/src/main/java/com/dremio/proto/flight/commands/Command.java +++ /dev/null @@ -1,277 +0,0 @@ -// Generated by http://code.google.com/p/protostuff/ ... DO NOT EDIT! 
-// Generated from protobuf - -package com.dremio.proto.flight.commands; - -import javax.annotation.Generated; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.util.Objects; - -import io.protostuff.ByteString; -import io.protostuff.GraphIOUtil; -import io.protostuff.Input; -import io.protostuff.Message; -import io.protostuff.Output; -import io.protostuff.Schema; - -import io.protostuff.UninitializedMessageException; -@Generated("dremio_java_bean.java.stg") -public final class Command implements Externalizable, Message, Schema -{ - - public static Schema getSchema() - { - return DEFAULT_INSTANCE; - } - - public static Command getDefaultInstance() - { - return DEFAULT_INSTANCE; - } - - static final Command DEFAULT_INSTANCE = new Command(); - - - private String - query; - private Boolean - parallel; - private Boolean - coalesce; - private ByteString - ticket; - - public Command() - { - - } - - public Command( - String query, - Boolean parallel, - Boolean coalesce, - ByteString ticket - ) - { - this.query = query; - this.parallel = parallel; - this.coalesce = coalesce; - this.ticket = ticket; - } - - // getters and setters - - // query - public String - getQuery() - { - return query; - } - - public Command setQuery(String - query) - { - this.query = query; - return this; - } - - // parallel - public Boolean - getParallel() - { - return parallel; - } - - public Command setParallel(Boolean - parallel) - { - this.parallel = parallel; - return this; - } - - // coalesce - public Boolean - getCoalesce() - { - return coalesce; - } - - public Command setCoalesce(Boolean - coalesce) - { - this.coalesce = coalesce; - return this; - } - - // ticket - public ByteString - getTicket() - { - return ticket; - } - - public Command setTicket(ByteString - ticket) - { - this.ticket = ticket; - return this; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || this.getClass() != obj.getClass()) { - return false; - } - final Command that = (Command) obj; - return - Objects.equals(this.query, that.query) && - Objects.equals(this.parallel, that.parallel) && - Objects.equals(this.coalesce, that.coalesce) && - Objects.equals(this.ticket, that.ticket); - } - - @Override - public int hashCode() { - return Objects.hash(query, parallel, coalesce, ticket); - } - - @Override - public String toString() { - return "Command{" + - "query=" + query + - ", parallel=" + parallel + - ", coalesce=" + coalesce + - ", ticket=" + ticket + - '}'; - } - // java serialization - - public void readExternal(ObjectInput in) throws IOException - { - GraphIOUtil.mergeDelimitedFrom(in, this, this); - } - - public void writeExternal(ObjectOutput out) throws IOException - { - GraphIOUtil.writeDelimitedTo(out, this, this); - } - - // message method - - public Schema cachedSchema() - { - return DEFAULT_INSTANCE; - } - - // schema methods - - public Command newMessage() - { - return new Command(); - } - - public Class typeClass() - { - return Command.class; - } - - public String messageName() - { - return Command.class.getSimpleName(); - } - - public String messageFullName() - { - return Command.class.getName(); - } - - public boolean isInitialized(Command message) - { - return - message.query != null - && message.parallel != null - && message.coalesce != null - && message.ticket != null; - } - - public void mergeFrom(Input input, Command message) throws IOException - { - for(int number = 
input.readFieldNumber(this);; number = input.readFieldNumber(this)) - { - switch(number) - { - case 0: - return; - case 1: - message.query = input.readString(); - break; - case 2: - message.parallel = input.readBool(); - break; - case 3: - message.coalesce = input.readBool(); - break; - case 4: - message.ticket = input.readBytes(); - break; - default: - input.handleUnknownField(number, this); - } - } - } - - - public void writeTo(Output output, Command message) throws IOException - { - if(message.query == null) - throw new UninitializedMessageException(message); - output.writeString(1, message.query, false); - - if(message.parallel == null) - throw new UninitializedMessageException(message); - output.writeBool(2, message.parallel, false); - - if(message.coalesce == null) - throw new UninitializedMessageException(message); - output.writeBool(3, message.coalesce, false); - - if(message.ticket == null) - throw new UninitializedMessageException(message); - output.writeBytes(4, message.ticket, false); - } - - public String getFieldName(int number) - { - switch(number) - { - case 1: return "query"; - case 2: return "parallel"; - case 3: return "coalesce"; - case 4: return "ticket"; - default: return null; - } - } - - public int getFieldNumber(String name) - { - final Integer number = __fieldMap.get(name); - return number == null ? 0 : number.intValue(); - } - - private static final java.util.HashMap __fieldMap = new java.util.HashMap(); - static - { - __fieldMap.put("query", 1); - __fieldMap.put("parallel", 2); - __fieldMap.put("coalesce", 3); - __fieldMap.put("ticket", 4); - } - - -} diff --git a/src/main/java/com/dremio/spark/DefaultSource.java b/src/main/java/com/dremio/spark/DefaultSource.java index 2c82006..f7af1da 100644 --- a/src/main/java/com/dremio/spark/DefaultSource.java +++ b/src/main/java/com/dremio/spark/DefaultSource.java @@ -1,3 +1,18 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.dremio.spark; import org.apache.arrow.memory.RootAllocator; @@ -7,8 +22,9 @@ import org.apache.spark.sql.sources.v2.reader.DataSourceReader; public class DefaultSource implements DataSourceV2, ReadSupport { - private final RootAllocator rootAllocator = new RootAllocator(); - public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { - return new FlightDataSourceReader(dataSourceOptions, rootAllocator.newChildAllocator(dataSourceOptions.toString(), 0, rootAllocator.getLimit())); - } + private final RootAllocator rootAllocator = new RootAllocator(); + + public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { + return new FlightDataSourceReader(dataSourceOptions, rootAllocator.newChildAllocator(dataSourceOptions.toString(), 0, rootAllocator.getLimit())); + } } diff --git a/src/main/java/com/dremio/spark/FlightClientFactory.java b/src/main/java/com/dremio/spark/FlightClientFactory.java index f95be42..11ea7f4 100644 --- a/src/main/java/com/dremio/spark/FlightClientFactory.java +++ b/src/main/java/com/dremio/spark/FlightClientFactory.java @@ -1,3 +1,18 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.dremio.spark; import org.apache.arrow.flight.FlightClient; @@ -5,23 +20,23 @@ import org.apache.arrow.memory.BufferAllocator; public class FlightClientFactory { - private BufferAllocator allocator; - private Location defaultLocation; - private final String username; - private final String password; + private BufferAllocator allocator; + private Location defaultLocation; + private final String username; + private final String password; - public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password) { - this.allocator = allocator; - this.defaultLocation = defaultLocation; - this.username = username; - this.password = password; - } + public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password) { + this.allocator = allocator; + this.defaultLocation = defaultLocation; + this.username = username; + this.password = password; + } - public FlightClient apply() { - FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); - client.authenticateBasic(username, password); - return client; + public FlightClient apply() { + FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); + client.authenticateBasic(username, password); + return client; - } + } } diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index e28852f..2115837 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -1,5 +1,22 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.dremio.spark; +import java.io.IOException; + import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; @@ -11,43 +28,40 @@ import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +public class FlightDataReader implements InputPartitionReader { + private final FlightClient client; + private final FlightStream stream; + private final BufferAllocator allocator; -import java.io.IOException; + public FlightDataReader( + byte[] ticket, + String defaultHost, + int defaultPort) { + this.allocator = new RootAllocator(); + client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations + client.authenticateBasic("dremio", "dremio123"); + stream = client.getStream(new Ticket(ticket)); + } -public class FlightDataReader implements InputPartitionReader { - private final FlightClient client; - private final FlightStream stream; - private final BufferAllocator allocator; - - public FlightDataReader( - byte[] ticket, - String defaultHost, - int defaultPort) { - this.allocator = new RootAllocator(); - client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations - client.authenticateBasic("dremio", "dremio123"); - stream = client.getStream(new Ticket(ticket)); - } - - @Override - public boolean next() throws IOException { - return stream.next(); - } - - @Override - public ColumnarBatch get() { - ColumnarBatch batch = new ColumnarBatch( - stream.getRoot().getFieldVectors() - .stream() - .map(ArrowColumnVector::new) - .toArray(ColumnVector[]::new) - ); - batch.setNumRows(stream.getRoot().getRowCount()); - return batch; - } - - @Override - public void close() throws IOException { + @Override + public boolean next() throws IOException { + return stream.next(); + } + + @Override + public ColumnarBatch get() { + ColumnarBatch batch = new ColumnarBatch( + stream.getRoot().getFieldVectors() + .stream() + .map(ArrowColumnVector::new) + .toArray(ColumnVector[]::new) + ); + batch.setNumRows(stream.getRoot().getRowCount()); + return batch; + } + + @Override + public void close() throws IOException { // try { // client.close(); @@ -56,5 +70,5 @@ public void close() throws IOException { // } catch (Exception e) { // throw new IOException(e); // } - } + } } diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java index 6a50db6..77c9b65 100644 --- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java +++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java @@ -1,3 +1,18 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.dremio.spark; import org.apache.spark.sql.sources.v2.reader.InputPartition; @@ -6,27 +21,27 @@ public class FlightDataReaderFactory implements InputPartition { - private byte[] ticket; - private final String defaultHost; - private final int defaultPort; + private byte[] ticket; + private final String defaultHost; + private final int defaultPort; - public FlightDataReaderFactory( - byte[] ticket, - String defaultHost, - int defaultPort) { - this.ticket = ticket; - this.defaultHost = defaultHost; - this.defaultPort = defaultPort; - } + public FlightDataReaderFactory( + byte[] ticket, + String defaultHost, + int defaultPort) { + this.ticket = ticket; + this.defaultHost = defaultHost; + this.defaultPort = defaultPort; + } - @Override - public String[] preferredLocations() { - return new String[]{defaultHost}; - } + @Override + public String[] preferredLocations() { + return new String[]{defaultHost}; + } - @Override - public InputPartitionReader createPartitionReader() { - return new FlightDataReader(ticket, defaultHost, defaultPort); - } + @Override + public InputPartitionReader createPartitionReader() { + return new FlightDataReader(ticket, defaultHost, defaultPort); + } } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index 2e5ebda..8dd3733 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -1,15 +1,32 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.dremio.spark; -import com.dremio.proto.flight.commands.Command; -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; -import io.protostuff.ByteString; -import io.protostuff.LinkedBuffer; -import io.protostuff.ProtostuffIOUtil; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -31,253 +48,248 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; import scala.collection.JavaConversions; public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters, SupportsPushDownRequiredColumns { - private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); - private static final Joiner WHERE_JOINER = Joiner.on(" and "); - private static final Joiner PROJ_JOINER = Joiner.on(", "); - private FlightInfo info; - private FlightDescriptor descriptor; - private StructType schema; - private final LinkedBuffer buffer = LinkedBuffer.allocate(); - private final Location defaultLocation; - private final FlightClientFactory clientFactory; - private final boolean parallel; - private String sql; - private Filter[] pushed; + private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); + private static final Joiner WHERE_JOINER = Joiner.on(" and "); + private static final Joiner PROJ_JOINER = Joiner.on(", "); + private SchemaResult info; + private FlightDescriptor descriptor; + private StructType schema; + private final Location defaultLocation; + private final FlightClientFactory clientFactory; + private final boolean parallel; + private String sql; + private Filter[] pushed; - public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { - defaultLocation = Location.forGrpcInsecure( - dataSourceOptions.get("host").orElse("localhost"), - dataSourceOptions.getInt("port", 47470) - ); - clientFactory = new FlightClientFactory(allocator, - defaultLocation, - dataSourceOptions.get("username").orElse("anonymous"), - dataSourceOptions.get("password").orElse(null) - ); - parallel = dataSourceOptions.getBoolean("parallel", false); - sql = dataSourceOptions.get("path").orElse(""); - descriptor = getDescriptor(dataSourceOptions.getBoolean("isSql", false), sql); - try (FlightClient client = clientFactory.apply()) { - info = client.getInfo(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { + defaultLocation = Location.forGrpcInsecure( + dataSourceOptions.get("host").orElse("localhost"), + dataSourceOptions.getInt("port", 47470) + ); + clientFactory = new FlightClientFactory(allocator, + defaultLocation, + dataSourceOptions.get("username").orElse("anonymous"), + dataSourceOptions.get("password").orElse(null) + ); + parallel = 
dataSourceOptions.getBoolean("parallel", false);
+    sql = dataSourceOptions.get("path").orElse("");
+    descriptor = getDescriptor(sql);
+    try (FlightClient client = clientFactory.apply()) {
+      if (parallel) {
+        Iterator<Result> res = client.doAction(new Action("PARALLEL"));
+        res.forEachRemaining(Object::toString);
+      }
+      info = client.getSchema(descriptor);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
     }
+  }
 
-    private FlightDescriptor getDescriptor(boolean isSql, String path) {
-        String query = (!isSql) ? ("select * from " + path) : path;
-        byte[] message = ProtostuffIOUtil.toByteArray(new Command(query , parallel, false, ByteString.EMPTY), Command.getSchema(), buffer);
-        buffer.clear();
-        return FlightDescriptor.command(message);
-    }
+  private FlightDescriptor getDescriptor(String path) {
+    return FlightDescriptor.command(path.getBytes());
+  }
 
-    private StructType readSchemaImpl() {
-        StructField[] fields = info.getSchema().getFields().stream()
-                .map(field ->
-                        new StructField(field.getName(),
-                                sparkFromArrow(field.getFieldType()),
-                                field.isNullable(),
-                                Metadata.empty()))
-                .toArray(StructField[]::new);
-        return new StructType(fields);
-    }
+  private StructType readSchemaImpl() {
+    StructField[] fields = info.getSchema().getFields().stream()
+      .map(field ->
+        new StructField(field.getName(),
+          sparkFromArrow(field.getFieldType()),
+          field.isNullable(),
+          Metadata.empty()))
+      .toArray(StructField[]::new);
+    return new StructType(fields);
+  }
 
-    public StructType readSchema() {
-        if (schema == null) {
-            schema = readSchemaImpl();
-        }
-        return schema;
+  public StructType readSchema() {
+    if (schema == null) {
+      schema = readSchemaImpl();
     }
+    return schema;
+  }
 
-    private DataType sparkFromArrow(FieldType fieldType) {
-        switch (fieldType.getType().getTypeID()) {
-            case Null:
-                return DataTypes.NullType;
-            case Struct:
-                throw new UnsupportedOperationException("have not implemented Struct type yet");
-            case List:
-                throw new UnsupportedOperationException("have not implemented List type yet");
-            case FixedSizeList:
-                throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
-            case Union:
-                throw new UnsupportedOperationException("have not implemented Union type yet");
-            case Int:
-                ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
-                int bitWidth = intType.getBitWidth();
-                if (bitWidth == 8) {
-                    return DataTypes.ByteType;
-                } else if (bitWidth == 16) {
-                    return DataTypes.ShortType;
-                } else if (bitWidth == 32) {
-                    return DataTypes.IntegerType;
-                } else if (bitWidth == 64) {
-                    return DataTypes.LongType;
-                }
-                throw new UnsupportedOperationException("unknow int type with bitwidth " + bitWidth);
-            case FloatingPoint:
-                ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType();
-                FloatingPointPrecision precision = floatType.getPrecision();
-                switch (precision) {
-                    case HALF:
-                    case SINGLE:
-                        return DataTypes.FloatType;
-                    case DOUBLE:
-                        return DataTypes.DoubleType;
-                }
-            case Utf8:
-                return DataTypes.StringType;
-            case Binary:
-            case FixedSizeBinary:
-                return DataTypes.BinaryType;
-            case Bool:
-                return DataTypes.BooleanType;
-            case Decimal:
-                throw new UnsupportedOperationException("have not implemented Decimal type yet");
-            case Date:
-                return DataTypes.DateType;
-            case Time:
-                return DataTypes.TimestampType; //note i don't know what this will do!
-            case Timestamp:
-                return DataTypes.TimestampType;
-            case Interval:
-                return DataTypes.CalendarIntervalType;
-            case NONE:
-                return DataTypes.NullType;
-        }
-        throw new IllegalStateException("Unexpected value: " + fieldType);
-    }
+  private DataType sparkFromArrow(FieldType fieldType) {
+    switch (fieldType.getType().getTypeID()) {
+      case Null:
+        return DataTypes.NullType;
+      case Struct:
+        throw new UnsupportedOperationException("have not implemented Struct type yet");
+      case List:
+        throw new UnsupportedOperationException("have not implemented List type yet");
+      case FixedSizeList:
+        throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
+      case Union:
+        throw new UnsupportedOperationException("have not implemented Union type yet");
+      case Int:
+        ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
+        int bitWidth = intType.getBitWidth();
+        if (bitWidth == 8) {
+          return DataTypes.ByteType;
+        } else if (bitWidth == 16) {
+          return DataTypes.ShortType;
+        } else if (bitWidth == 32) {
+          return DataTypes.IntegerType;
+        } else if (bitWidth == 64) {
+          return DataTypes.LongType;
+        }
+        throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth);
+      case FloatingPoint:
+        ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType();
+        FloatingPointPrecision precision = floatType.getPrecision();
+        switch (precision) {
+          case HALF:
+          case SINGLE:
+            return DataTypes.FloatType;
+          case DOUBLE:
+            return DataTypes.DoubleType;
+        }
+      case Utf8:
+        return DataTypes.StringType;
+      case Binary:
+      case FixedSizeBinary:
+        return DataTypes.BinaryType;
+      case Bool:
+        return DataTypes.BooleanType;
+      case Decimal:
+        throw new UnsupportedOperationException("have not implemented Decimal type yet");
+      case Date:
+        return DataTypes.DateType;
+      case Time:
+        return DataTypes.TimestampType; // note: mapping Arrow Time to Spark TimestampType is unverified
+      case Timestamp:
+        return DataTypes.TimestampType;
+      case Interval:
+        return DataTypes.CalendarIntervalType;
+      case NONE:
+        return DataTypes.NullType;
+    }
+    throw new IllegalStateException("Unexpected value: " + fieldType);
+  }
 
   @Override
-    public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
-        if (parallel) {
-            return planBatchInputPartitionsParallel();
-        }
-        return planBatchInputPartitionsSerial(info);
+  public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() {
+    return planBatchInputPartitionsParallel();
+  }
 
-    private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsParallel() {
-        byte[] message = ProtostuffIOUtil.toByteArray(new Command("", true, true, ByteString.copyFrom(info.getEndpoints().get(0).getTicket().getBytes())), Command.getSchema(), buffer);
-        buffer.clear();
-        try (FlightClient client = clientFactory.apply()) {
-            FlightInfo info = client.getInfo(FlightDescriptor.command(message));
-            return planBatchInputPartitionsSerial(info);
-        } catch (InterruptedException e) {
-            throw new RuntimeException(e);
-        }
-    }
+  private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsParallel() {
+    try (FlightClient client = clientFactory.apply()) {
+      FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes()));
+      return planBatchInputPartitionsSerial(info);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+  }
 
-    private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsSerial(FlightInfo info) {
-        return info.getEndpoints().stream().map(endpoint -> {
-            Location location = (endpoint.getLocations().isEmpty()) ?
-                    Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) :
-                    endpoint.getLocations().get(0);
-            return new FlightDataReaderFactory(endpoint.getTicket().getBytes(),
-                    location.getUri().getHost(),
-                    location.getUri().getPort());
-        }).collect(Collectors.toList());
-    }
+  private List<InputPartition<ColumnarBatch>> planBatchInputPartitionsSerial(FlightInfo info) {
+    return info.getEndpoints().stream().map(endpoint -> {
+      Location location = (endpoint.getLocations().isEmpty()) ?
+        Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) :
+        endpoint.getLocations().get(0);
+      return new FlightDataReaderFactory(endpoint.getTicket().getBytes(),
+        location.getUri().getHost(),
+        location.getUri().getPort());
+    }).collect(Collectors.toList());
+  }
 
-    @Override
-    public Filter[] pushFilters(Filter[] filters) {
-        List<Filter> notPushed = Lists.newArrayList();
-        List<Filter> pushed = Lists.newArrayList();
-        for (Filter filter: filters) {
-            boolean isPushed = canBePushed(filter);
-            if (isPushed) {
-                pushed.add(filter);
-            } else {
-                notPushed.add(filter);
-            }
-        }
-        this.pushed = pushed.toArray(new Filter[0]);
-        if (!pushed.isEmpty()) {
-            String whereClause = generateWhereClause(pushed);
-            mergeWhereDescriptors(whereClause);
-            try (FlightClient client = clientFactory.apply()) {
-                info = client.getInfo(descriptor);
-            } catch (InterruptedException e) {
-                throw new RuntimeException(e);
-            }
-        }
-        return notPushed.toArray(new Filter[0]);
-    }
+  @Override
+  public Filter[] pushFilters(Filter[] filters) {
+    List<Filter> notPushed = Lists.newArrayList();
+    List<Filter> pushed = Lists.newArrayList();
+    for (Filter filter : filters) {
+      boolean isPushed = canBePushed(filter);
+      if (isPushed) {
+        pushed.add(filter);
+      } else {
+        notPushed.add(filter);
+      }
+    }
+    this.pushed = pushed.toArray(new Filter[0]);
+    if (!pushed.isEmpty()) {
+      String whereClause = generateWhereClause(pushed);
+      mergeWhereDescriptors(whereClause);
+      try (FlightClient client = clientFactory.apply()) {
+        info = client.getSchema(descriptor);
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }
+    return notPushed.toArray(new Filter[0]);
+  }
 
-    private void mergeWhereDescriptors(String whereClause) {
-        sql = String.format("select * from (%s) where %s", sql, whereClause);
-        descriptor = getDescriptor(true, sql);
-    }
+  private void mergeWhereDescriptors(String whereClause) {
+    sql = String.format("select * from (%s) where %s", sql, whereClause);
+    descriptor = getDescriptor(sql);
+  }
 
-    private void mergeProjDescriptors(String projClause) {
-        sql = String.format("select %s from (%s)", projClause, sql);
-        descriptor = getDescriptor(true, sql);
-    }
+  private void mergeProjDescriptors(String projClause) {
+    sql = String.format("select %s from (%s)", projClause, sql);
+    descriptor = getDescriptor(sql);
+  }
 
-    private String generateWhereClause(List<Filter> pushed) {
-        List<String> filterStr = Lists.newArrayList();
-        for (Filter filter: pushed) {
-            if (filter instanceof IsNotNull) {
-                filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute()));
-            } else if (filter instanceof EqualTo){
-                filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value())));
-            }
-        }
-        return WHERE_JOINER.join(filterStr);
-    }
+  private String generateWhereClause(List<Filter> pushed) {
+    List<String> filterStr = Lists.newArrayList();
+    for (Filter filter : pushed) {
+      if (filter instanceof IsNotNull) {
+        filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute()));
+      } else if (filter instanceof EqualTo) {
+        filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value())));
+      }
+    }
+    return WHERE_JOINER.join(filterStr);
+  }
 
-    private String valueToString(Object value) {
-        if (value instanceof String) {
-            return String.format("'%s'", value);
-        }
-        return value.toString();
-    }
+  private String valueToString(Object value) {
+    if (value instanceof String) {
+      return String.format("'%s'", value);
+    }
+    return value.toString();
+  }
 
-    private boolean canBePushed(Filter filter) {
-        if (filter instanceof IsNotNull) {
-            return true;
-        } else if (filter instanceof EqualTo){
-            return true;
-        }
-        LOGGER.error("Cant push filter of type " + filter.toString());
-        return false;
-    }
+  private boolean canBePushed(Filter filter) {
+    if (filter instanceof IsNotNull) {
+      return true;
+    } else if (filter instanceof EqualTo) {
+      return true;
+    }
+    LOGGER.error("Can't push filter of type " + filter.toString());
+    return false;
+  }
 
-    @Override
-    public Filter[] pushedFilters() {
-        return pushed;
-    }
+  @Override
+  public Filter[] pushedFilters() {
+    return pushed;
+  }
 
-    @Override
-    public void pruneColumns(StructType requiredSchema) {
-        if (requiredSchema.toSeq().isEmpty()) {
-            return;
-        }
-        StructType schema = readSchema();
-        List<String> fields = Lists.newArrayList();
-        List<StructField> fieldsLeft = Lists.newArrayList();
-        Map<String, StructField> fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f->f));
-        for (StructField field: JavaConversions.seqAsJavaList(requiredSchema.toSeq())) {
-            String name = field.name();
-            StructField f = fieldNames.remove(name);
-            if (f != null) {
-                fields.add(String.format("\"%s\"",name));
-                fieldsLeft.add(f);
-            }
-        }
-        if (!fieldNames.isEmpty()) {
-            this.schema = new StructType(fieldsLeft.toArray(new StructField[0]));
-            mergeProjDescriptors(PROJ_JOINER.join(fields));
-            try (FlightClient client = clientFactory.apply()) {
-                info = client.getInfo(descriptor);
-            } catch (InterruptedException e) {
-                throw new RuntimeException(e);
-            }
-        }
+  @Override
+  public void pruneColumns(StructType requiredSchema) {
+    if (requiredSchema.toSeq().isEmpty()) {
+      return;
+    }
+    StructType schema = readSchema();
+    List<String> fields = Lists.newArrayList();
+    List<StructField> fieldsLeft = Lists.newArrayList();
+    Map<String, StructField> fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f -> f));
+    for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) {
+      String name = field.name();
+      StructField f = fieldNames.remove(name);
+      if (f != null) {
+        fields.add(String.format("\"%s\"", name));
+        fieldsLeft.add(f);
+      }
+    }
+    if (!fieldNames.isEmpty()) {
+      this.schema = new StructType(fieldsLeft.toArray(new StructField[0]));
+      mergeProjDescriptors(PROJ_JOINER.join(fields));
+      try (FlightClient client = clientFactory.apply()) {
+        info = client.getSchema(descriptor);
+      } catch (InterruptedException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
 }
diff --git a/src/main/java/com/dremio/spark/FlightSparkContext.java b/src/main/java/com/dremio/spark/FlightSparkContext.java
index b7c5ae2..ce911e5 100644
--- a/src/main/java/com/dremio/spark/FlightSparkContext.java
+++ 
b/src/main/java/com/dremio/spark/FlightSparkContext.java @@ -1,3 +1,18 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.dremio.spark; import org.apache.spark.SparkConf; @@ -10,56 +25,52 @@ public class FlightSparkContext { - private SparkConf conf; - private final DataFrameReader reader; + private SparkConf conf; + private final DataFrameReader reader; - private FlightSparkContext(SparkContext sc, SparkConf conf) { - SQLContext sqlContext = SQLContext.getOrCreate(sc); - this.conf = conf; - reader = sqlContext.read().format("com.dremio.spark"); - } + private FlightSparkContext(SparkContext sc, SparkConf conf) { + SQLContext sqlContext = SQLContext.getOrCreate(sc); + this.conf = conf; + reader = sqlContext.read().format("com.dremio.spark"); + } - public static FlightSparkContext flightContext(JavaSparkContext sc) { - return new FlightSparkContext(sc.sc(), sc.getConf()); - } + public static FlightSparkContext flightContext(JavaSparkContext sc) { + return new FlightSparkContext(sc.sc(), sc.getConf()); + } - public Dataset read(String s) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("isSql", false) - .option("parallel", false) - .load(s); - } + public Dataset read(String s) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("parallel", false) + .load(s); + } - public Dataset readSql(String s) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("isSql", true) - .option("parallel", false) - .load(s); - } + public Dataset readSql(String s) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("parallel", false) + .load(s); + } - public Dataset read(String s, boolean parallel) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("isSql", false) - .option("parallel", parallel) - .load(s); - } + public Dataset read(String s, boolean parallel) { + return reader.option("port", 
Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("parallel", parallel) + .load(s); + } - public Dataset readSql(String s, boolean parallel) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("isSql", true) - .option("parallel", parallel) - .load(s); - } + public Dataset readSql(String s, boolean parallel) { + return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) + .option("host", conf.get("spark.flight.endpoint.host")) + .option("username", conf.get("spark.flight.auth.username")) + .option("password", conf.get("spark.flight.auth.password")) + .option("parallel", parallel) + .load(s); + } } diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index 9dcb940..c181eeb 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -1,3 +1,18 @@ +/* + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.dremio.spark; import org.apache.spark.SparkConf; @@ -96,44 +111,44 @@ public void testProject() { @Test public void testParallel() { String easySql = "select * from sys.options"; - String hardSql = "select * from \"@dremio\".test"; - Dataset df = csc.readSql(hardSql, true); +// String hardSql = "select * from \"@dremio\".test"; + Dataset df = csc.readSql(easySql, true); SizeConsumer c = new SizeConsumer(); SizeConsumer c2 = new SizeConsumer(); Dataset dff = df.select("bid", "ask", "symbol").filter(df.col("symbol").equalTo("USDCAD")); dff.toLocalIterator().forEachRemaining(c); long width = c.width; long length = c.length; - csc.readSql(hardSql, true).toLocalIterator().forEachRemaining(c2); + csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c2); long widthOriginal = c2.width; long lengthOriginal = c2.length; Assert.assertTrue(width < widthOriginal); Assert.assertTrue(length < lengthOriginal); } - @Ignore - @Test - public void testSpeed() { - long[] jdbcT = new long[16]; - long[] flightT = new long[16]; - Properties connectionProperties = new Properties(); - connectionProperties.put("user", "dremio"); - connectionProperties.put("password", "dremio123"); - long jdbcC = 0; - long flightC = 0; - for (int i=0;i<4;i++) { - long now = System.currentTimeMillis(); - Dataset jdbc = SQLContext.getOrCreate(sc.sc()).read().jdbc("jdbc:dremio:direct=localhost:31010", "\"@dremio\".sdd", connectionProperties); - jdbcC = jdbc.count(); - long then = System.currentTimeMillis(); - flightC = csc.read("@dremio.sdd").count(); - long andHereWeAre = System.currentTimeMillis(); - jdbcT[i] = then-now; - flightT[i] = andHereWeAre - then; - } - for (int i =0;i<16;i++) { - System.out.println("Trial " + i + ": Flight took " + flightT[i] + " and jdbc took " + jdbcT[i]); - } - System.out.println("Fetched " + jdbcC + " row from jdbc and " + flightC + " from flight"); - } +// @Ignore +// @Test +// public void testSpeed() { +// long[] jdbcT = new long[16]; +// long[] flightT = new long[16]; +// Properties connectionProperties = new Properties(); +// connectionProperties.put("user", "dremio"); +// connectionProperties.put("password", "dremio123"); +// long jdbcC = 0; +// long flightC = 0; +// for (int i=0;i<4;i++) { +// long now = System.currentTimeMillis(); +// Dataset jdbc = SQLContext.getOrCreate(sc.sc()).read().jdbc("jdbc:dremio:direct=localhost:31010", "\"@dremio\".sdd", connectionProperties); +// jdbcC = jdbc.count(); +// long then = System.currentTimeMillis(); +// flightC = csc.read("@dremio.sdd").count(); +// long andHereWeAre = System.currentTimeMillis(); +// jdbcT[i] = then-now; +// flightT[i] = andHereWeAre - then; +// } +// for (int i =0;i<16;i++) { +// System.out.println("Trial " + i + ": Flight took " + flightT[i] + " and jdbc took " + jdbcT[i]); +// } +// System.out.println("Fetched " + jdbcC + " row from jdbc and " + flightC + " from flight"); +// } } diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index 4a54f7d..1deee4a 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -1,13 +1,13 @@ + + META-INF.native.libnetty_ + META-INF.native.libcdap_netty_ + + + META-INF.native.netty_ + META-INF.native.cdap_netty_ + + + + true + shaded + + + + @@ -328,6 +421,14 @@ limitations under the License. log4j log4j + + org.apache.arrow + arrow-format + + + org.apache.arrow + arrow-vector + @@ -351,6 +452,14 @@ limitations under the License. 
javax.servlet servlet-api + + org.apache.arrow + arrow-format + + + org.apache.arrow + arrow-vector + @@ -359,7 +468,18 @@ limitations under the License. ${arrow.version} shaded + + org.scala-lang + scala-library + 2.11.6 + + + org.scalatest + scalatest_2.11 + 2.2.5 + compile + org.slf4j jul-to-slf4j diff --git a/src/main/java/com/dremio/spark/FlightClientFactory.java b/src/main/java/com/dremio/spark/FlightClientFactory.java index 11ea7f4..68b1321 100644 --- a/src/main/java/com/dremio/spark/FlightClientFactory.java +++ b/src/main/java/com/dremio/spark/FlightClientFactory.java @@ -29,7 +29,7 @@ public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, this.allocator = allocator; this.defaultLocation = defaultLocation; this.username = username; - this.password = password; + this.password = password.equals("$NULL$") ? null : password; } public FlightClient apply() { @@ -39,4 +39,11 @@ public FlightClient apply() { } + public String getUsername() { + return username; + } + + public String getPassword() { + return password; + } } diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index 2115837..06e96f9 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -24,11 +24,13 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.vectorized.ArrowColumnVector; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class FlightDataReader implements InputPartitionReader { + private final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); private final FlightClient client; private final FlightStream stream; private final BufferAllocator allocator; @@ -36,10 +38,11 @@ public class FlightDataReader implements InputPartitionReader { public FlightDataReader( byte[] ticket, String defaultHost, - int defaultPort) { + int defaultPort, String username, String password) { this.allocator = new RootAllocator(); - client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations - client.authenticateBasic("dremio", "dremio123"); + logger.warn("setting up a data reader at host {} and port {} with ticket {}", defaultHost, defaultPort, new String(ticket)); + client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations & ssl + client.authenticateBasic(username, password); stream = client.getStream(new Ticket(ticket)); } @@ -53,7 +56,7 @@ public ColumnarBatch get() { ColumnarBatch batch = new ColumnarBatch( stream.getRoot().getFieldVectors() .stream() - .map(ArrowColumnVector::new) + .map(ModernArrowColumnVector::new) .toArray(ColumnVector[]::new) ); batch.setNumRows(stream.getRoot().getRowCount()); diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java index 77c9b65..4e550ed 100644 --- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java +++ b/src/main/java/com/dremio/spark/FlightDataReaderFactory.java @@ -24,14 +24,18 @@ public class FlightDataReaderFactory implements InputPartition { private byte[] ticket; private final String defaultHost; private final int 
defaultPort; + private final String username; + private final String password; public FlightDataReaderFactory( byte[] ticket, String defaultHost, - int defaultPort) { + int defaultPort, String username, String password) { this.ticket = ticket; this.defaultHost = defaultHost; this.defaultPort = defaultPort; + this.username = username; + this.password = password; } @Override @@ -41,7 +45,7 @@ public String[] preferredLocations() { @Override public InputPartitionReader createPartitionReader() { - return new FlightDataReader(ticket, defaultHost, defaultPort); + return new FlightDataReader(ticket, defaultHost, defaultPort, username, password); } } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index 8dd3733..cc311ef 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -33,7 +33,11 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; @@ -186,13 +190,16 @@ private List> planBatchInputPartitionsParallel() { } private List> planBatchInputPartitionsSerial(FlightInfo info) { + LOGGER.warn("planning partitions for endpoints {}", Joiner.on(", ").join(info.getEndpoints().stream().map(e -> e.getLocations().get(0).getUri().toString()).collect(Collectors.toList()))); return info.getEndpoints().stream().map(endpoint -> { Location location = (endpoint.getLocations().isEmpty()) ? 
Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) :
         endpoint.getLocations().get(0);
       return new FlightDataReaderFactory(endpoint.getTicket().getBytes(),
         location.getUri().getHost(),
-        location.getUri().getPort());
+        location.getUri().getPort(),
+        clientFactory.getUsername(),
+        clientFactory.getPassword());
     }).collect(Collectors.toList());
   }
@@ -238,6 +245,14 @@ private String generateWhereClause(List<Filter> pushed) {
         filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute()));
       } else if (filter instanceof EqualTo) {
         filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value())));
+      } else if (filter instanceof GreaterThan) {
+        filterStr.add(String.format("\"%s\" > %s", ((GreaterThan) filter).attribute(), valueToString(((GreaterThan) filter).value())));
+      } else if (filter instanceof GreaterThanOrEqual) {
+        filterStr.add(String.format("\"%s\" >= %s", ((GreaterThanOrEqual) filter).attribute(), valueToString(((GreaterThanOrEqual) filter).value())));
+      } else if (filter instanceof LessThan) {
+        filterStr.add(String.format("\"%s\" < %s", ((LessThan) filter).attribute(), valueToString(((LessThan) filter).value())));
+      } else if (filter instanceof LessThanOrEqual) {
+        filterStr.add(String.format("\"%s\" <= %s", ((LessThanOrEqual) filter).attribute(), valueToString(((LessThanOrEqual) filter).value())));
       }
     }
     return WHERE_JOINER.join(filterStr);
@@ -256,6 +271,18 @@ private boolean canBePushed(Filter filter) {
     } else if (filter instanceof EqualTo) {
       return true;
     }
+    if (filter instanceof GreaterThan) {
+      return true;
+    }
+    if (filter instanceof GreaterThanOrEqual) {
+      return true;
+    }
+    if (filter instanceof LessThan) {
+      return true;
+    }
+    if (filter instanceof LessThanOrEqual) {
+      return true;
+    }
     LOGGER.error("Can't push filter of type " + filter.toString());
     return false;
   }
diff --git a/src/main/java/com/dremio/spark/ModernArrowColumnVector.java b/src/main/java/com/dremio/spark/ModernArrowColumnVector.java
new file mode 100644
index 0000000..38bbf92
--- /dev/null
+++ b/src/main/java/com/dremio/spark/ModernArrowColumnVector.java
@@ -0,0 +1,495 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.spark; + +import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableVarCharHolder; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.execution.arrow.ModernArrowUtils; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not + * supported. + */ +@InterfaceStability.Evolving +public final class ModernArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private ModernArrowColumnVector[] childColumns; + + @Override + public boolean hasNull() { + return accessor.getNullCount() > 0; + } + + @Override + public int numNulls() { + return accessor.getNullCount(); + } + + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + childColumns[i] = null; + } + childColumns = null; + } + accessor.close(); + } + + @Override + public boolean isNullAt(int rowId) { + return accessor.isNullAt(rowId); + } + + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); + } + + @Override + public byte getByte(int rowId) { + return accessor.getByte(rowId); + } + + @Override + public short getShort(int rowId) { + return accessor.getShort(rowId); + } + + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); + } + + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); + } + + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); + } + + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getBinary(rowId); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) return null; + return accessor.getArray(rowId); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ModernArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } + + public ModernArrowColumnVector(ValueVector vector) { + super(ModernArrowUtils.fromArrowField(vector.getField())); + + if (vector instanceof BitVector) { + accessor = new BooleanAccessor((BitVector) vector); + } else if (vector instanceof TinyIntVector) { + accessor = new ByteAccessor((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + accessor = new ShortAccessor((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + accessor = new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + accessor = new LongAccessor((BigIntVector) vector); + } else if (vector instanceof 
Float4Vector) { + accessor = new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + accessor = new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof DecimalVector) { + accessor = new DecimalAccessor((DecimalVector) vector); + } else if (vector instanceof VarCharVector) { + accessor = new StringAccessor((VarCharVector) vector); + } else if (vector instanceof VarBinaryVector) { + accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + accessor = new ArrayAccessor(listVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + accessor = new StructAccessor(structVector); + + childColumns = new ModernArrowColumnVector[structVector.size()]; + for (int i = 0; i < childColumns.length; ++i) { + childColumns[i] = new ModernArrowColumnVector(structVector.getVectorById(i)); + } + } else { + throw new UnsupportedOperationException(); + } + } + + private abstract static class ArrowVectorAccessor { + + private final ValueVector vector; + + ArrowVectorAccessor(ValueVector vector) { + this.vector = vector; + } + + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { + return vector.isNull(rowId); + } + + final int getNullCount() { + return vector.getNullCount(); + } + + final void close() { + vector.close(); + } + + boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + } + + private static class BooleanAccessor extends ArrowVectorAccessor { + + private final BitVector accessor; + + BooleanAccessor(BitVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final boolean getBoolean(int rowId) { + return accessor.get(rowId) == 1; + } + } + + private static class ByteAccessor extends ArrowVectorAccessor { + + private final TinyIntVector accessor; + + ByteAccessor(TinyIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final byte getByte(int rowId) { + return accessor.get(rowId); + } + } + + private static class ShortAccessor extends ArrowVectorAccessor { + + private final SmallIntVector accessor; + + ShortAccessor(SmallIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final short getShort(int rowId) { + return accessor.get(rowId); + } + } + + private static class IntAccessor extends 
ArrowVectorAccessor { + + private final IntVector accessor; + + IntAccessor(IntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class LongAccessor extends ArrowVectorAccessor { + + private final BigIntVector accessor; + + LongAccessor(BigIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + + private static class FloatAccessor extends ArrowVectorAccessor { + + private final Float4Vector accessor; + + FloatAccessor(Float4Vector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final float getFloat(int rowId) { + return accessor.get(rowId); + } + } + + private static class DoubleAccessor extends ArrowVectorAccessor { + + private final Float8Vector accessor; + + DoubleAccessor(Float8Vector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final double getDouble(int rowId) { + return accessor.get(rowId); + } + } + + private static class DecimalAccessor extends ArrowVectorAccessor { + + private final DecimalVector accessor; + + DecimalAccessor(DecimalVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(accessor.getObject(rowId), precision, scale); + } + } + + private static class StringAccessor extends ArrowVectorAccessor { + + private final VarCharVector accessor; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + StringAccessor(VarCharVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final UTF8String getUTF8String(int rowId) { + accessor.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + } + + private static class BinaryAccessor extends ArrowVectorAccessor { + + private final VarBinaryVector accessor; + + BinaryAccessor(VarBinaryVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final byte[] getBinary(int rowId) { + return accessor.getObject(rowId); + } + } + + private static class DateAccessor extends ArrowVectorAccessor { + + private final DateDayVector accessor; + + DateAccessor(DateDayVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final TimeStampMicroTZVector accessor; + + TimestampAccessor(TimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + + private static class ArrayAccessor extends ArrowVectorAccessor { + + private final ListVector accessor; + private final ModernArrowColumnVector arrayData; + + ArrayAccessor(ListVector vector) { + super(vector); + this.accessor = vector; + this.arrayData = new ModernArrowColumnVector(vector.getDataVector()); + } + + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return 
super.isNullAt(rowId); + } + } + + @Override + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = accessor.getOffsetBuffer(); + int index = rowId * ListVector.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); + } + } + + /** + * Any call to "get" method will throw UnsupportedOperationException. + * + * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses + * getStruct() method defined in the parent class. Any call to "get" method in this class is a + * bug in the code. + * + */ + private static class StructAccessor extends ArrowVectorAccessor { + + StructAccessor(StructVector vector) { + super(vector); + } + } +} diff --git a/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala b/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala new file mode 100644 index 0000000..aef1915 --- /dev/null +++ b/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala @@ -0,0 +1,134 @@ +/** + * Copyright (C) 2019 Ryan Murray + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.arrow + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import scala.collection.JavaConverters._ + +object ModernArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. + + /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType => ArrowType.Utf8.INSTANCE + case BinaryType => ArrowType.Binary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => + if (timeZoneId == null) { + throw new UnsupportedOperationException( + s"${TimestampType.catalogString} must supply timeZoneId parameter") + } else { + new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + } + case _ => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") + } + + /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { + dt match { + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field(name, fieldType, + fields.map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + }.toSeq.asJava) + case dataType => + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields) + case arrowType => fromArrowType(arrowType) + } + } + + /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { + new Schema(schema.map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }) + } + + /** Return Map with conf settings to be used in ArrowPythonRunner */ + def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { + val timeZoneConf = if (conf.pandasRespectSessionTimeZone) { + Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + } else { + Nil + } + val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + Map(timeZoneConf ++ pandasColsByName: _*) + } +} From 83892ea5f338bf94e19a3397123f3c41f15526e6 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Sun, 26 Jan 2020 15:39:49 +0000 Subject: [PATCH 13/38] fix tests --- pom.xml | 8 +- .../com/dremio/spark/FlightDataReader.java | 13 +- .../dremio/spark/FlightDataSourceReader.java | 4 +- .../java/com/dremio/spark/TestConnector.java | 353 ++++++++++++------ 4 files changed, 259 insertions(+), 119 deletions(-) diff --git a/pom.xml b/pom.xml index cc0f66a..015d2c2 100644 --- a/pom.xml +++ b/pom.xml @@ -29,7 +29,6 @@ 0.15.1 2.4.4 1.7.25 - 1.4.4 @@ -486,6 +485,13 @@ limitations under the License. 
${dep.slf4j.version} test + + org.apache.arrow + arrow-flight + ${arrow.version} + tests + test + org.slf4j diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/com/dremio/spark/FlightDataReader.java index 06e96f9..7f43473 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/com/dremio/spark/FlightDataReader.java @@ -65,13 +65,12 @@ public ColumnarBatch get() { @Override public void close() throws IOException { - -// try { -// client.close(); -// stream.close(); + try { + client.close(); + stream.close(); // allocator.close(); -// } catch (Exception e) { -// throw new IOException(e); -// } + } catch (Exception e) { + throw new IOException(e); + } } } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/com/dremio/spark/FlightDataSourceReader.java index cc311ef..d71bf6c 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/com/dremio/spark/FlightDataSourceReader.java @@ -191,7 +191,7 @@ private List> planBatchInputPartitionsParallel() { private List> planBatchInputPartitionsSerial(FlightInfo info) { LOGGER.warn("planning partitions for endpoints {}", Joiner.on(", ").join(info.getEndpoints().stream().map(e -> e.getLocations().get(0).getUri().toString()).collect(Collectors.toList()))); - return info.getEndpoints().stream().map(endpoint -> { + List> batches = info.getEndpoints().stream().map(endpoint -> { Location location = (endpoint.getLocations().isEmpty()) ? Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) : endpoint.getLocations().get(0); @@ -201,6 +201,8 @@ private List> planBatchInputPartitionsSerial(Fligh clientFactory.getUsername(), clientFactory.getPassword()); }).collect(Collectors.toList()); + LOGGER.info("Created {} batches from arrow endpoints", batches.size()); + return batches; } @Override diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/com/dremio/spark/TestConnector.java index c181eeb..f9c43f5 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/com/dremio/spark/TestConnector.java @@ -15,140 +15,273 @@ */ package com.dremio.spark; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import 
org.apache.spark.sql.SQLContext; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; - -import java.util.Properties; -import java.util.function.Consumer; +import com.google.common.collect.ImmutableList; public class TestConnector { - private static SparkConf conf; - private static JavaSparkContext sc; - private static FlightSparkContext csc; - - @BeforeClass - public static void setUp() throws Exception { - conf = new SparkConf() - .setAppName("flightTest") - .setMaster("local[*]") - .set("spark.driver.allowMultipleContexts","true") - .set("spark.flight.endpoint.host", "localhost") - .set("spark.flight.endpoint.port", "47470") - .set("spark.flight.auth.username", "dremio") - .set("spark.flight.auth.password", "dremio123") - ; - sc = new JavaSparkContext(conf); - - csc = FlightSparkContext.flightContext(sc); - } + private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + private static Location location; + private static FlightServer server; + private static SparkConf conf; + private static JavaSparkContext sc; + private static FlightSparkContext csc; - @AfterClass - public static void tearDown() throws Exception { - sc.close(); - } + @BeforeClass + public static void setUp() throws Exception { + server = FlightTestUtil.getStartedServer(location -> FlightServer.builder(allocator, location, new TestProducer()).authHandler( + new ServerAuthHandler() { + @Override + public Optional isValid(byte[] token) { + return Optional.of("xxx"); + } - @Test - public void testConnect() { - csc.read("sys.options"); - } + @Override + public boolean authenticate(ServerAuthSender outgoing, Iterator incoming) { + incoming.next(); + outgoing.send(new byte[0]); + return true; + } + }).build() + ); + location = server.getLocation(); + conf = new SparkConf() + .setAppName("flightTest") + .setMaster("local[*]") + .set("spark.driver.allowMultipleContexts", "true") + .set("spark.flight.endpoint.host", location.getUri().getHost()) + .set("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort())) + .set("spark.flight.auth.username", "xxx") + .set("spark.flight.auth.password", "yyy") + ; + sc = new JavaSparkContext(conf); + csc = FlightSparkContext.flightContext(sc); + } - @Test - public void testRead() { - long count = csc.read("sys.options").count(); - Assert.assertTrue(count > 0); - } + @AfterClass + public static void tearDown() throws Exception { + AutoCloseables.close(server, allocator, sc); + } + + @Test + public void testConnect() { + csc.read("test.table"); + } + + @Test + public void testRead() { + long count = csc.read("test.table").count(); + Assert.assertEquals(20, count); + } + + @Test + public void testSql() { + long count = csc.readSql("select * from test.table").count(); + Assert.assertEquals(20, count); + } + + @Test + public void testFilter() { + Dataset df = csc.readSql("select * from test.table"); + long count = df.filter(df.col("symbol").equalTo("USDCAD")).count(); + long countOriginal = csc.readSql("select * from test.table").count(); + Assert.assertTrue(count < countOriginal); + } - @Test - public void testReadWithQuotes() { - long count = csc.read("\"sys\".options").count(); - Assert.assertTrue(count > 0); + private static class SizeConsumer implements Consumer { + private int length = 0; + private int width = 0; + + @Override + public void accept(Row row) { + length += 1; + width = row.length(); } + } + + @Test + public 
void testProject() { + Dataset df = csc.readSql("select * from test.table"); + SizeConsumer c = new SizeConsumer(); + df.select("bid", "ask", "symbol").toLocalIterator().forEachRemaining(c); + long count = c.width; + long countOriginal = csc.readSql("select * from test.table").columns().length; + Assert.assertTrue(count < countOriginal); + } + + @Test + public void testParallel() { + String easySql = "select * from test.table"; + SizeConsumer c = new SizeConsumer(); + csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c); + long width = c.width; + long length = c.length; + Assert.assertEquals(5, width); + Assert.assertEquals(40, length); + } - @Test - public void testSql() { - long count = csc.readSql("select * from \"sys\".options").count(); - Assert.assertTrue(count > 0); + private static class TestProducer extends NoOpFlightProducer { + private boolean parallel = false; + + @Override + public void doAction(CallContext context, Action action, StreamListener listener) { + parallel = true; + listener.onNext(new Result("ok".getBytes())); + listener.onCompleted(); } - @Test - public void testFilter() { - Dataset df = csc.readSql("select * from \"sys\".options"); - long count = df.filter(df.col("kind").equalTo("LONG")).count(); - long countOriginal = csc.readSql("select * from \"sys\".options").count(); - Assert.assertTrue(count < countOriginal); + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + Schema schema; + List endpoints; + if (parallel) { + endpoints = ImmutableList.of(new FlightEndpoint(new Ticket(descriptor.getCommand()), location), + new FlightEndpoint(new Ticket(descriptor.getCommand()), location)); + } else { + endpoints = ImmutableList.of(new FlightEndpoint(new Ticket(descriptor.getCommand()), location)); + } + if (new String(descriptor.getCommand()).equals("select \"bid\", \"ask\", \"symbol\" from (select * from test.table))")) { + schema = new Schema(ImmutableList.of( + Field.nullable("bid", Types.MinorType.FLOAT8.getType()), + Field.nullable("ask", Types.MinorType.FLOAT8.getType()), + Field.nullable("symbol", Types.MinorType.VARCHAR.getType())) + ); + + } else { + schema = new Schema(ImmutableList.of( + Field.nullable("bid", Types.MinorType.FLOAT8.getType()), + Field.nullable("ask", Types.MinorType.FLOAT8.getType()), + Field.nullable("symbol", Types.MinorType.VARCHAR.getType()), + Field.nullable("bidsize", Types.MinorType.BIGINT.getType()), + Field.nullable("asksize", Types.MinorType.BIGINT.getType())) + ); + } + return new FlightInfo(schema, descriptor, endpoints, 1000000, 10); } - private static class SizeConsumer implements Consumer { - private int length = 0; - private int width = 0; + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + final int size = (new String(ticket.getBytes()).contains("USDCAD")) ? 5 : 10; - @Override - public void accept(Row row) { - length+=1; - width = row.length(); + if (new String(ticket.getBytes()).equals("select \"bid\", \"ask\", \"symbol\" from (select * from test.table))")) { + Float8Vector b = new Float8Vector("bid", allocator); + Float8Vector a = new Float8Vector("ask", allocator); + VarCharVector s = new VarCharVector("symbol", allocator); + + VectorSchemaRoot root = VectorSchemaRoot.of(b, a, s); + listener.start(root); + + //batch 1 + root.allocateNew(); + for (int i = 0; i < size; i++) { + b.set(i, (double) i); + a.set(i, (double) i); + s.set(i, (i % 2 == 0) ? 
new Text("USDCAD") : new Text("EURUSD")); } - } + b.setValueCount(size); + a.setValueCount(size); + s.setValueCount(size); + root.setRowCount(size); + listener.putNext(); - @Test - public void testProject() { - Dataset df = csc.readSql("select * from \"sys\".options"); - SizeConsumer c = new SizeConsumer(); - df.select("name", "kind", "type").toLocalIterator().forEachRemaining(c); - long count = c.width; - long countOriginal = csc.readSql("select * from \"sys\".options").columns().length; - Assert.assertTrue(count < countOriginal); - } + // batch 2 + + root.allocateNew(); + for (int i = 0; i < size; i++) { + b.set(i, (double) i); + a.set(i, (double) i); + s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD")); + } + b.setValueCount(size); + a.setValueCount(size); + s.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + root.clear(); + listener.completed(); + } else { + BigIntVector bs = new BigIntVector("bidsize", allocator); + BigIntVector as = new BigIntVector("asksize", allocator); + Float8Vector b = new Float8Vector("bid", allocator); + Float8Vector a = new Float8Vector("ask", allocator); + VarCharVector s = new VarCharVector("symbol", allocator); - @Test - public void testParallel() { - String easySql = "select * from sys.options"; -// String hardSql = "select * from \"@dremio\".test"; - Dataset df = csc.readSql(easySql, true); - SizeConsumer c = new SizeConsumer(); - SizeConsumer c2 = new SizeConsumer(); - Dataset dff = df.select("bid", "ask", "symbol").filter(df.col("symbol").equalTo("USDCAD")); - dff.toLocalIterator().forEachRemaining(c); - long width = c.width; - long length = c.length; - csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c2); - long widthOriginal = c2.width; - long lengthOriginal = c2.length; - Assert.assertTrue(width < widthOriginal); - Assert.assertTrue(length < lengthOriginal); + VectorSchemaRoot root = VectorSchemaRoot.of(b, a, s, bs, as); + listener.start(root); + + //batch 1 + root.allocateNew(); + for (int i = 0; i < size; i++) { + bs.set(i, (long) i); + as.set(i, (long) i); + b.set(i, (double) i); + a.set(i, (double) i); + s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD")); + } + bs.setValueCount(size); + as.setValueCount(size); + b.setValueCount(size); + a.setValueCount(size); + s.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + + // batch 2 + + root.allocateNew(); + for (int i = 0; i < size; i++) { + bs.set(i, (long) i); + as.set(i, (long) i); + b.set(i, (double) i); + a.set(i, (double) i); + s.set(i, (i % 2 == 0) ? 
new Text("USDCAD") : new Text("EURUSD")); + } + bs.setValueCount(size); + as.setValueCount(size); + b.setValueCount(size); + a.setValueCount(size); + s.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + root.clear(); + listener.completed(); + } } -// @Ignore -// @Test -// public void testSpeed() { -// long[] jdbcT = new long[16]; -// long[] flightT = new long[16]; -// Properties connectionProperties = new Properties(); -// connectionProperties.put("user", "dremio"); -// connectionProperties.put("password", "dremio123"); -// long jdbcC = 0; -// long flightC = 0; -// for (int i=0;i<4;i++) { -// long now = System.currentTimeMillis(); -// Dataset jdbc = SQLContext.getOrCreate(sc.sc()).read().jdbc("jdbc:dremio:direct=localhost:31010", "\"@dremio\".sdd", connectionProperties); -// jdbcC = jdbc.count(); -// long then = System.currentTimeMillis(); -// flightC = csc.read("@dremio.sdd").count(); -// long andHereWeAre = System.currentTimeMillis(); -// jdbcT[i] = then-now; -// flightT[i] = andHereWeAre - then; -// } -// for (int i =0;i<16;i++) { -// System.out.println("Trial " + i + ": Flight took " + flightT[i] + " and jdbc took " + jdbcT[i]); -// } -// System.out.println("Fetched " + jdbcC + " row from jdbc and " + flightC + " from flight"); -// } + + } } From 4392f61b6905ad7c2c72cca605a7efbb1d583407 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Sun, 26 Jan 2020 20:54:22 +0000 Subject: [PATCH 14/38] relocate --- pom.xml | 2 +- .../arrow/flight}/spark/DefaultSource.java | 2 +- .../flight}/spark/FlightClientFactory.java | 18 ++++++++++++++++-- .../arrow/flight}/spark/FlightDataReader.java | 11 +++++++++-- .../flight}/spark/FlightDataReaderFactory.java | 8 +++++--- .../flight}/spark/FlightDataSourceReader.java | 16 ++++++---------- .../flight}/spark/FlightSparkContext.java | 4 ++-- .../flight}/spark/ModernArrowColumnVector.java | 2 +- .../arrow/flight}/spark/TestConnector.java | 2 +- 9 files changed, 42 insertions(+), 23 deletions(-) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/DefaultSource.java (96%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/FlightClientFactory.java (75%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/FlightDataReader.java (87%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/FlightDataReaderFactory.java (88%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/FlightDataSourceReader.java (96%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/FlightSparkContext.java (96%) rename src/main/java/{com/dremio => org/apache/arrow/flight}/spark/ModernArrowColumnVector.java (99%) rename src/test/java/{com/dremio => org/apache/arrow/flight}/spark/TestConnector.java (99%) diff --git a/pom.xml b/pom.xml index 015d2c2..3524c69 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - com.dremio + org.apache.arrow.flight.spark flight-spark-source 1.0-SNAPSHOT diff --git a/src/main/java/com/dremio/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java similarity index 96% rename from src/main/java/com/dremio/spark/DefaultSource.java rename to src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index f7af1da..68fc5ec 100644 --- a/src/main/java/com/dremio/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -13,7 +13,7 @@ * See the License for the specific language 
governing permissions and * limitations under the License. */ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import org.apache.arrow.memory.RootAllocator; import org.apache.spark.sql.sources.v2.DataSourceOptions; diff --git a/src/main/java/com/dremio/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java similarity index 75% rename from src/main/java/com/dremio/spark/FlightClientFactory.java rename to src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index 68b1321..3473d1f 100644 --- a/src/main/java/com/dremio/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.dremio.spark; +package org.apache.arrow.flight.spark; +import java.util.Iterator; + +import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; import org.apache.arrow.memory.BufferAllocator; public class FlightClientFactory { @@ -24,17 +28,23 @@ public class FlightClientFactory { private Location defaultLocation; private final String username; private final String password; + private boolean parallel; - public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password) { + public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password, boolean parallel) { this.allocator = allocator; this.defaultLocation = defaultLocation; this.username = username; this.password = password.equals("$NULL$") ? null : password; + this.parallel = parallel; } public FlightClient apply() { FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); client.authenticateBasic(username, password); + if (parallel) { + Iterator res = client.doAction(new Action("PARALLEL")); + res.forEachRemaining(Object::toString); + } return client; } @@ -46,4 +56,8 @@ public String getUsername() { public String getPassword() { return password; } + + public boolean isParallel() { + return parallel; + } } diff --git a/src/main/java/com/dremio/spark/FlightDataReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java similarity index 87% rename from src/main/java/com/dremio/spark/FlightDataReader.java rename to src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java index 7f43473..52b3145 100644 --- a/src/main/java/com/dremio/spark/FlightDataReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import java.io.IOException; +import java.util.Iterator; +import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Result; import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -38,11 +41,15 @@ public class FlightDataReader implements InputPartitionReader { public FlightDataReader( byte[] ticket, String defaultHost, - int defaultPort, String username, String password) { + int defaultPort, String username, String password, boolean parallel) { this.allocator = new RootAllocator(); logger.warn("setting up a data reader at host {} and port {} with ticket {}", defaultHost, defaultPort, new String(ticket)); client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations & ssl client.authenticateBasic(username, password); + if (parallel) { + Iterator res = client.doAction(new Action("PARALLEL")); + res.forEachRemaining(Object::toString); + } stream = client.getStream(new Ticket(ticket)); } diff --git a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java similarity index 88% rename from src/main/java/com/dremio/spark/FlightDataReaderFactory.java rename to src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java index 4e550ed..667222f 100644 --- a/src/main/java/com/dremio/spark/FlightDataReaderFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; @@ -26,16 +26,18 @@ public class FlightDataReaderFactory implements InputPartition { private final int defaultPort; private final String username; private final String password; + private boolean parallel; public FlightDataReaderFactory( byte[] ticket, String defaultHost, - int defaultPort, String username, String password) { + int defaultPort, String username, String password, boolean parallel) { this.ticket = ticket; this.defaultHost = defaultHost; this.defaultPort = defaultPort; this.username = username; this.password = password; + this.parallel = parallel; } @Override @@ -45,7 +47,7 @@ public String[] preferredLocations() { @Override public InputPartitionReader createPartitionReader() { - return new FlightDataReader(ticket, defaultHost, defaultPort, username, password); + return new FlightDataReader(ticket, defaultHost, defaultPort, username, password, parallel); } } diff --git a/src/main/java/com/dremio/spark/FlightDataSourceReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java similarity index 96% rename from src/main/java/com/dremio/spark/FlightDataSourceReader.java rename to src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java index d71bf6c..077f21f 100644 --- a/src/main/java/com/dremio/spark/FlightDataSourceReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import java.util.Iterator; import java.util.List; @@ -66,7 +66,6 @@ public class FlightDataSourceReader implements SupportsScanColumnarBatch, Suppor private StructType schema; private final Location defaultLocation; private final FlightClientFactory clientFactory; - private final boolean parallel; private String sql; private Filter[] pushed; @@ -78,16 +77,12 @@ public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocat clientFactory = new FlightClientFactory(allocator, defaultLocation, dataSourceOptions.get("username").orElse("anonymous"), - dataSourceOptions.get("password").orElse(null) + dataSourceOptions.get("password").orElse(null), + dataSourceOptions.getBoolean("parallel", false) ); - parallel = dataSourceOptions.getBoolean("parallel", false); sql = dataSourceOptions.get("path").orElse(""); descriptor = getDescriptor(sql); try (FlightClient client = clientFactory.apply()) { - if (parallel) { - Iterator res = client.doAction(new Action("PARALLEL")); - res.forEachRemaining(Object::toString); - } info = client.getSchema(descriptor); } catch (InterruptedException e) { throw new RuntimeException(e); @@ -140,7 +135,7 @@ private DataType sparkFromArrow(FieldType fieldType) { } else if (bitWidth == 64) { return DataTypes.LongType; } - throw new UnsupportedOperationException("unknow int type with bitwidth " + bitWidth); + throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth); case FloatingPoint: ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType(); FloatingPointPrecision precision = floatType.getPrecision(); @@ -199,7 +194,8 @@ private List> planBatchInputPartitionsSerial(Fligh location.getUri().getHost(), location.getUri().getPort(), clientFactory.getUsername(), - 
clientFactory.getPassword()); + clientFactory.getPassword(), + clientFactory.isParallel()); }).collect(Collectors.toList()); LOGGER.info("Created {} batches from arrow endpoints", batches.size()); return batches; diff --git a/src/main/java/com/dremio/spark/FlightSparkContext.java b/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java similarity index 96% rename from src/main/java/com/dremio/spark/FlightSparkContext.java rename to src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java index ce911e5..c6d65ec 100644 --- a/src/main/java/com/dremio/spark/FlightSparkContext.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; @@ -31,7 +31,7 @@ public class FlightSparkContext { private FlightSparkContext(SparkContext sc, SparkConf conf) { SQLContext sqlContext = SQLContext.getOrCreate(sc); this.conf = conf; - reader = sqlContext.read().format("com.dremio.spark"); + reader = sqlContext.read().format("org.apache.arrow.flight.spark"); } public static FlightSparkContext flightContext(JavaSparkContext sc) { diff --git a/src/main/java/com/dremio/spark/ModernArrowColumnVector.java b/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java similarity index 99% rename from src/main/java/com/dremio/spark/ModernArrowColumnVector.java rename to src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java index 38bbf92..28a76e9 100644 --- a/src/main/java/com/dremio/spark/ModernArrowColumnVector.java +++ b/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java @@ -30,7 +30,7 @@ * limitations under the License. */ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.*; diff --git a/src/test/java/com/dremio/spark/TestConnector.java b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java similarity index 99% rename from src/test/java/com/dremio/spark/TestConnector.java rename to src/test/java/org/apache/arrow/flight/spark/TestConnector.java index f9c43f5..861691f 100644 --- a/src/test/java/com/dremio/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.dremio.spark; +package org.apache.arrow.flight.spark; import java.util.Iterator; import java.util.List; From 91ffd8dc8c036060f697051ebd1deeaec8673bad Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Mon, 17 Feb 2020 19:05:20 +0000 Subject: [PATCH 15/38] Cleanup and fixes for v1.0 * fix a few bugs * clean up code * correctly cancel etc * correctly close on failure * handle datetimes correctly --- .../arrow/flight/spark/DefaultSource.java | 4 +- .../flight/spark/FlightClientFactory.java | 17 ++-- .../arrow/flight/spark/FlightDataReader.java | 31 +++++-- .../flight/spark/FlightDataSourceReader.java | 16 ++-- .../flight/spark/ModernArrowColumnVector.java | 92 ++++++++++++++++--- .../execution/arrow/ModernArrowUtils.scala | 4 +- .../arrow/flight/spark/TestConnector.java | 2 +- 7 files changed, 123 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index 68fc5ec..f814aaf 100644 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -15,16 +15,14 @@ */ package org.apache.arrow.flight.spark; -import org.apache.arrow.memory.RootAllocator; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.DataSourceReader; public class DefaultSource implements DataSourceV2, ReadSupport { - private final RootAllocator rootAllocator = new RootAllocator(); public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { - return new FlightDataSourceReader(dataSourceOptions, rootAllocator.newChildAllocator(dataSourceOptions.toString(), 0, rootAllocator.getLimit())); + return new FlightDataSourceReader(dataSourceOptions); } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index 3473d1f..a8d09c4 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -22,16 +22,16 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.flight.Result; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; -public class FlightClientFactory { - private BufferAllocator allocator; - private Location defaultLocation; +public class FlightClientFactory implements AutoCloseable { + private final BufferAllocator allocator = new RootAllocator(); + private final Location defaultLocation; private final String username; private final String password; - private boolean parallel; + private final boolean parallel; - public FlightClientFactory(BufferAllocator allocator, Location defaultLocation, String username, String password, boolean parallel) { - this.allocator = allocator; + public FlightClientFactory(Location defaultLocation, String username, String password, boolean parallel) { this.defaultLocation = defaultLocation; this.username = username; this.password = password.equals("$NULL$") ? 
null : password; @@ -60,4 +60,9 @@ public String getPassword() { public boolean isParallel() { return parallel; } + + @Override + public void close() throws Exception { + allocator.close(); + } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java index 52b3145..9e6bd26 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java @@ -26,6 +26,7 @@ import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -33,24 +34,33 @@ import org.slf4j.LoggerFactory; public class FlightDataReader implements InputPartitionReader { - private final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); + private static final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); private final FlightClient client; private final FlightStream stream; private final BufferAllocator allocator; + private final FlightClientFactory clientFactory; + private final byte[] ticket; + private boolean parallel; public FlightDataReader( byte[] ticket, String defaultHost, - int defaultPort, String username, String password, boolean parallel) { + int defaultPort, + String username, + String password, + boolean parallel) { + this.parallel = parallel; this.allocator = new RootAllocator(); logger.warn("setting up a data reader at host {} and port {} with ticket {}", defaultHost, defaultPort, new String(ticket)); - client = FlightClient.builder(this.allocator, Location.forGrpcInsecure(defaultHost, defaultPort)).build(); //todo multiple locations & ssl - client.authenticateBasic(username, password); + clientFactory = new FlightClientFactory(Location.forGrpcInsecure(defaultHost, defaultPort), username, password, parallel); + client = clientFactory.apply(); + stream = client.getStream(new Ticket(ticket)); + this.ticket = ticket; if (parallel) { - Iterator res = client.doAction(new Action("PARALLEL")); - res.forEachRemaining(Object::toString); + logger.debug("doing create action for ticket {}", new String(ticket)); + client.doAction(new Action("create", ticket)).forEachRemaining(Object::toString); + logger.debug("completed create action for ticket {}", new String(ticket)); } - stream = client.getStream(new Ticket(ticket)); } @Override @@ -73,9 +83,10 @@ public ColumnarBatch get() { @Override public void close() throws IOException { try { - client.close(); - stream.close(); -// allocator.close(); + if (parallel) { + client.doAction(new Action("delete", ticket)).forEachRemaining(Object::toString); + } + AutoCloseables.close(client, stream, clientFactory, allocator); } catch (Exception e) { throw new IOException(e); } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java index 077f21f..afe549a 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java @@ -15,19 +15,15 @@ */ package org.apache.arrow.flight.spark; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import 
org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; import org.apache.arrow.flight.SchemaResult; -import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; @@ -57,7 +53,7 @@ import scala.collection.JavaConversions; -public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters, SupportsPushDownRequiredColumns { +public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters, SupportsPushDownRequiredColumns, AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); private static final Joiner WHERE_JOINER = Joiner.on(" and "); private static final Joiner PROJ_JOINER = Joiner.on(", "); @@ -69,12 +65,12 @@ public class FlightDataSourceReader implements SupportsScanColumnarBatch, Suppor private String sql; private Filter[] pushed; - public FlightDataSourceReader(DataSourceOptions dataSourceOptions, BufferAllocator allocator) { + public FlightDataSourceReader(DataSourceOptions dataSourceOptions) { defaultLocation = Location.forGrpcInsecure( dataSourceOptions.get("host").orElse("localhost"), dataSourceOptions.getInt("port", 47470) ); - clientFactory = new FlightClientFactory(allocator, + clientFactory = new FlightClientFactory( defaultLocation, dataSourceOptions.get("username").orElse("anonymous"), dataSourceOptions.get("password").orElse(null), @@ -171,6 +167,7 @@ private DataType sparkFromArrow(FieldType fieldType) { @Override public List> planBatchInputPartitions() { + System.out.println("planBatchInputPartitions"); return planBatchInputPartitionsParallel(); } @@ -317,4 +314,9 @@ public void pruneColumns(StructType requiredSchema) { } } } + + @Override + public void close() throws Exception { + clientFactory.close(); + } } diff --git a/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java b/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java index 28a76e9..189d8ef 100644 --- a/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java +++ b/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java @@ -32,19 +32,35 @@ package org.apache.arrow.flight.spark; -import io.netty.buffer.ArrowBuf; -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; import 
org.apache.arrow.vector.holders.NullableVarCharHolder; - import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.arrow.ModernArrowUtils; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; +import io.netty.buffer.ArrowBuf; + /** * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not * supported. @@ -119,25 +135,33 @@ public double getDouble(int rowId) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; + if (isNullAt(rowId)) { + return null; + } return accessor.getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int rowId) { - if (isNullAt(rowId)) return null; + if (isNullAt(rowId)) { + return null; + } return accessor.getUTF8String(rowId); } @Override public byte[] getBinary(int rowId) { - if (isNullAt(rowId)) return null; + if (isNullAt(rowId)) { + return null; + } return accessor.getBinary(rowId); } @Override public ColumnarArray getArray(int rowId) { - if (isNullAt(rowId)) return null; + if (isNullAt(rowId)) { + return null; + } return accessor.getArray(rowId); } @@ -147,7 +171,9 @@ public ColumnarMap getMap(int rowId) { } @Override - public ModernArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } + public ModernArrowColumnVector getChild(int ordinal) { + return childColumns[ordinal]; + } public ModernArrowColumnVector(ValueVector vector) { super(ModernArrowUtils.fromArrowField(vector.getField())); @@ -174,8 +200,12 @@ public ModernArrowColumnVector(ValueVector vector) { accessor = new BinaryAccessor((VarBinaryVector) vector); } else if (vector instanceof DateDayVector) { accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof DateMilliVector) { + accessor = new DateMilliAccessor((DateMilliVector) vector); } else if (vector instanceof TimeStampMicroTZVector) { accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMilliVector) { + accessor = new TimestampMilliAccessor((TimeStampMilliVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -188,6 +218,7 @@ public ModernArrowColumnVector(ValueVector vector) { childColumns[i] = new ModernArrowColumnVector(structVector.getVectorById(i)); } } else { + System.out.println(vector); throw new UnsupportedOperationException(); } } @@ -374,7 +405,9 @@ private static class DecimalAccessor extends ArrowVectorAccessor { @Override final Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; + if (isNullAt(rowId)) { + return null; + } return Decimal.apply(accessor.getObject(rowId), precision, scale); } } @@ -432,9 +465,26 @@ final int getInt(int rowId) { } } + private static class DateMilliAccessor extends ArrowVectorAccessor { + + private final DateMilliVector accessor; + private final double val = 1.0 / (24. * 60. * 60. 
* 1000.); + + DateMilliAccessor(DateMilliVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + System.out.println(accessor.get(rowId) + " " + (accessor.get(rowId) * val) + " " + val); + return (int) (accessor.get(rowId) * val); + } + } + private static class TimestampAccessor extends ArrowVectorAccessor { - private final TimeStampMicroTZVector accessor; + private final TimeStampVector accessor; TimestampAccessor(TimeStampMicroTZVector vector) { super(vector); @@ -447,6 +497,21 @@ final long getLong(int rowId) { } } + private static class TimestampMilliAccessor extends ArrowVectorAccessor { + + private final TimeStampVector accessor; + + TimestampMilliAccessor(TimeStampMilliVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId) * 1000; + } + } + private static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; @@ -480,11 +545,10 @@ final ColumnarArray getArray(int rowId) { /** * Any call to "get" method will throw UnsupportedOperationException. - * + *
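(An editorial aside on the unit conversions above; the sketch is not part of the patch.) DateMilliVector stores epoch milliseconds while Spark's DateType expects days, and TimeStampMilliVector stores milliseconds while Spark's TimestampType expects microseconds. DateMilliAccessor scales by a precomputed double reciprocal and casts, which truncates toward zero; an integer floor division gives the same result for post-epoch dates and the correct day for pre-epoch ones:

// Editorial sketch: integer alternatives to the scaling used by
// DateMilliAccessor and TimestampMilliAccessor. Math.floorDiv rounds toward
// negative infinity, so dates before 1970-01-01 still map to the right day,
// which casting (millis * (1.0 / 86_400_000.0)) to int does not guarantee.
final class MilliConversions {
    static int epochMillisToEpochDays(long epochMillis) {
        return (int) Math.floorDiv(epochMillis, 86_400_000L); // 24 * 60 * 60 * 1000
    }

    static long epochMillisToEpochMicros(long epochMillis) {
        return epochMillis * 1_000L; // the same scaling TimestampMilliAccessor applies
    }
}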

* Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses * getStruct() method defined in the parent class. Any call to "get" method in this class is a * bug in the code. - * */ private static class StructAccessor extends ArrowVectorAccessor { diff --git a/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala b/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala index aef1915..9bb953e 100644 --- a/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala +++ b/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala @@ -65,8 +65,8 @@ object ModernArrowUtils { case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) - case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType - case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case date: ArrowType.Date if date.getUnit == DateUnit.DAY || date.getUnit == DateUnit.MILLISECOND => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND || ts.getUnit == TimeUnit.MILLISECOND => TimestampType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } diff --git a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java index 861691f..42c2a2f 100644 --- a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -145,7 +145,7 @@ public void testProject() { @Test public void testParallel() { - String easySql = "select * from test.table"; + String easySql = "select * from \"@dremio\".tpch_spark limit 100000"; SizeConsumer c = new SizeConsumer(); csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c); long width = c.width; From bbaf8cf66f1a08a0ab143652e51795d810f2af39 Mon Sep 17 00:00:00 2001 From: Ryan Murray Date: Mon, 18 May 2020 14:13:30 +0200 Subject: [PATCH 16/38] Cleanup and fixes for v1.0 * refactor code * fixed a few more bugs * upgraded arrow libraries * benchmarked --- README.md | 6 +- pom.xml | 39 ++--- .../arrow/flight/spark/DefaultSource.java | 35 ++++- ...ctor.java => FlightArrowColumnVector.java} | 24 ++-- .../flight/spark/FlightClientFactory.java | 16 +-- .../arrow/flight/spark/FlightDataReader.java | 63 ++++---- .../flight/spark/FlightDataReaderFactory.java | 26 +--- .../flight/spark/FlightDataSourceReader.java | 136 +++++++++++++++--- ...rrowUtils.scala => FlightArrowUtils.scala} | 6 +- 9 files changed, 231 insertions(+), 120 deletions(-) rename src/main/java/org/apache/arrow/flight/spark/{ModernArrowColumnVector.java => FlightArrowColumnVector.java} (95%) rename src/main/scala/org/apache/spark/sql/execution/arrow/{ModernArrowUtils.scala => FlightArrowUtils.scala} (97%) diff --git a/README.md b/README.md index 1a8b507..7910ffd 100644 --- a/README.md +++ b/README.md @@ -17,10 +17,6 @@ It currently supports: It currently lacks: * support for all Spark/Arrow data types and filters -* Strongly tied to [Dremio's flight endpoint](https://github.com/dremio-hub/dremio-flight-connector) and should be abstracted -to generic Flight sources -* Needs to be updated to support new features in Arrow 0.15.0 * write interface to use `DoPut` to write Spark dataframes back to an Arrow Flight endpoint * leverage the transactional capabilities of the Spark Source V2 interface -* proper benchmark test 
-* CI build & tests +* publish benchmark test diff --git a/pom.xml b/pom.xml index 3524c69..7628398 100644 --- a/pom.xml +++ b/pom.xml @@ -26,8 +26,8 @@ 1.0-SNAPSHOT - 0.15.1 - 2.4.4 + 0.17.0 + 2.4.5 1.7.25 @@ -52,8 +52,8 @@ maven-checkstyle-plugin 3.1.0 - src/main/checkstyle/checkstyle-config.xml - src/main/checkstyle/checkstyle-suppressions.xml + ${project.basedir}/src/main/checkstyle/checkstyle-config.xml + ${project.basedir}/src/main/checkstyle/checkstyle-suppressions.xml @@ -306,6 +306,9 @@ limitations under the License. net.alchim31.maven scala-maven-plugin + + false + scala-compile-first @@ -337,7 +340,8 @@ limitations under the License. - org.apache.arrow:arrow-flight:shaded + org.apache.arrow:flight-core + org.apache.arrow:flight-grpc org.apache.arrow:arrow-vector org.apache.arrow:arrow-format org.apache.arrow:arrow-memory @@ -351,6 +355,7 @@ limitations under the License. com.google.api.grpc:proto-google-common-protos com.google.protobuf:protobuf-java com.google.guava:guava + io.perfmark:perfmark-api io.netty:netty-transport-native-unix-common @@ -389,6 +394,9 @@ limitations under the License. + + + true shaded @@ -398,11 +406,6 @@ limitations under the License. - - - - - org.apache.spark spark-core_2.11 @@ -463,9 +466,13 @@ limitations under the License. org.apache.arrow - arrow-flight + flight-core + ${arrow.version} + + + org.apache.arrow + flight-grpc ${arrow.version} - shaded org.scala-lang @@ -473,12 +480,6 @@ limitations under the License. 2.11.6 - - org.scalatest - scalatest_2.11 - 2.2.5 - compile - org.slf4j jul-to-slf4j @@ -487,7 +488,7 @@ limitations under the License. org.apache.arrow - arrow-flight + flight-core ${arrow.version} tests test diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index f814aaf..38d89d7 100644 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -15,6 +15,10 @@ */ package org.apache.arrow.flight.spark; +import org.apache.arrow.flight.Location; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.ReadSupport; @@ -22,7 +26,36 @@ public class DefaultSource implements DataSourceV2, ReadSupport { + private SparkSession lazySpark; + private JavaSparkContext lazySparkContext; + public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { - return new FlightDataSourceReader(dataSourceOptions); + Location defaultLocation = Location.forGrpcInsecure( + dataSourceOptions.get("host").orElse("localhost"), + dataSourceOptions.getInt("port", 47470) + ); + String sql = dataSourceOptions.get("path").orElse(""); + FlightDataSourceReader.FactoryOptions options = new FlightDataSourceReader.FactoryOptions( + defaultLocation, + sql, + dataSourceOptions.get("username").orElse("anonymous"), + dataSourceOptions.get("password").orElse(null), + dataSourceOptions.getBoolean("parallel", false), null); + Broadcast bOptions = lazySparkContext().broadcast(options); + return new FlightDataSourceReader(bOptions); + } + + private SparkSession lazySparkSession() { + if (lazySpark == null) { + this.lazySpark = SparkSession.builder().getOrCreate(); + } + return lazySpark; + } + + private JavaSparkContext lazySparkContext() { + if (lazySparkContext == null) 
{ + this.lazySparkContext = new JavaSparkContext(lazySparkSession().sparkContext()); + } + return lazySparkContext; } } diff --git a/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java similarity index 95% rename from src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java rename to src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java index 189d8ef..60bfa07 100644 --- a/src/main/java/org/apache/arrow/flight/spark/ModernArrowColumnVector.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java @@ -51,8 +51,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.holders.NullableVarCharHolder; -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.execution.arrow.ModernArrowUtils; +import org.apache.spark.sql.execution.arrow.FlightArrowUtils; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; @@ -63,13 +62,12 @@ /** * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not - * supported. + * supported. This is a copy of ArrowColumnVector with added support for DateMilli and TimestampMilli */ -@InterfaceStability.Evolving -public final class ModernArrowColumnVector extends ColumnVector { +public final class FlightArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; - private ModernArrowColumnVector[] childColumns; + private FlightArrowColumnVector[] childColumns; @Override public boolean hasNull() { @@ -171,12 +169,12 @@ public ColumnarMap getMap(int rowId) { } @Override - public ModernArrowColumnVector getChild(int ordinal) { + public FlightArrowColumnVector getChild(int ordinal) { return childColumns[ordinal]; } - public ModernArrowColumnVector(ValueVector vector) { - super(ModernArrowUtils.fromArrowField(vector.getField())); + public FlightArrowColumnVector(ValueVector vector) { + super(FlightArrowUtils.fromArrowField(vector.getField())); if (vector instanceof BitVector) { accessor = new BooleanAccessor((BitVector) vector); @@ -213,9 +211,9 @@ public ModernArrowColumnVector(ValueVector vector) { StructVector structVector = (StructVector) vector; accessor = new StructAccessor(structVector); - childColumns = new ModernArrowColumnVector[structVector.size()]; + childColumns = new FlightArrowColumnVector[structVector.size()]; for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new ModernArrowColumnVector(structVector.getVectorById(i)); + childColumns[i] = new FlightArrowColumnVector(structVector.getVectorById(i)); } } else { System.out.println(vector); @@ -515,12 +513,12 @@ final long getLong(int rowId) { private static class ArrayAccessor extends ArrowVectorAccessor { private final ListVector accessor; - private final ModernArrowColumnVector arrayData; + private final FlightArrowColumnVector arrayData; ArrayAccessor(ListVector vector) { super(vector); this.accessor = vector; - this.arrayData = new ModernArrowColumnVector(vector.getDataVector()); + this.arrayData = new FlightArrowColumnVector(vector.getDataVector()); } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index a8d09c4..c060743 100644 --- 
a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -34,7 +34,7 @@ public class FlightClientFactory implements AutoCloseable { public FlightClientFactory(Location defaultLocation, String username, String password, boolean parallel) { this.defaultLocation = defaultLocation; this.username = username; - this.password = password.equals("$NULL$") ? null : password; + this.password = (password == null || password.equals("$NULL$")) ? null : password; this.parallel = parallel; } @@ -49,20 +49,8 @@ public FlightClient apply() { } - public String getUsername() { - return username; - } - - public String getPassword() { - return password; - } - - public boolean isParallel() { - return parallel; - } - @Override - public void close() throws Exception { + public void close() { allocator.close(); } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java index 9e6bd26..ed0bea4 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java @@ -16,17 +16,16 @@ package org.apache.arrow.flight.spark; import java.io.IOException; -import java.util.Iterator; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; import org.apache.arrow.flight.Ticket; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -35,45 +34,56 @@ public class FlightDataReader implements InputPartitionReader { private static final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); - private final FlightClient client; - private final FlightStream stream; - private final BufferAllocator allocator; - private final FlightClientFactory clientFactory; - private final byte[] ticket; + private FlightClient client; + private FlightStream stream; + private BufferAllocator allocator = null; + private FlightClientFactory clientFactory; + private final Ticket ticket; + private final Broadcast options; + private final Location location; private boolean parallel; - public FlightDataReader( - byte[] ticket, - String defaultHost, - int defaultPort, - String username, - String password, - boolean parallel) { - this.parallel = parallel; + public FlightDataReader(Broadcast options) { + this.options = options; + this.location = Location.forGrpcInsecure(options.value().getHost(), options.value().getPort()); + this.ticket = new Ticket(options.value().getTicket()); + } + + private void start() { + if (allocator != null) { + return; + } + FlightDataSourceReader.FactoryOptions options = this.options.getValue(); + this.parallel = options.isParallel(); this.allocator = new RootAllocator(); - logger.warn("setting up a data reader at host {} and port {} with ticket {}", defaultHost, defaultPort, new String(ticket)); - clientFactory = new FlightClientFactory(Location.forGrpcInsecure(defaultHost, defaultPort), username, password, parallel); + logger.warn("setting up a data reader at host {} and port {} with ticket {}", 
options.getHost(), options.getPort(), new String(ticket.getBytes())); + clientFactory = new FlightClientFactory(location, options.getUsername(), options.getPassword(), parallel); client = clientFactory.apply(); - stream = client.getStream(new Ticket(ticket)); - this.ticket = ticket; + stream = client.getStream(ticket); if (parallel) { - logger.debug("doing create action for ticket {}", new String(ticket)); - client.doAction(new Action("create", ticket)).forEachRemaining(Object::toString); - logger.debug("completed create action for ticket {}", new String(ticket)); + logger.debug("doing create action for ticket {}", new String(ticket.getBytes())); + client.doAction(new Action("create", ticket.getBytes())).forEachRemaining(Object::toString); + logger.debug("completed create action for ticket {}", new String(ticket.getBytes())); } } @Override public boolean next() throws IOException { - return stream.next(); + start(); + try { + return stream.next(); + } catch (Throwable t) { + throw new IOException(t); + } } @Override public ColumnarBatch get() { + start(); ColumnarBatch batch = new ColumnarBatch( stream.getRoot().getFieldVectors() .stream() - .map(ModernArrowColumnVector::new) + .map(FlightArrowColumnVector::new) .toArray(ColumnVector[]::new) ); batch.setNumRows(stream.getRoot().getRowCount()); @@ -84,9 +94,10 @@ public ColumnarBatch get() { public void close() throws IOException { try { if (parallel) { - client.doAction(new Action("delete", ticket)).forEachRemaining(Object::toString); + client.doAction(new Action("delete", ticket.getBytes())).forEachRemaining(Object::toString); } - AutoCloseables.close(client, stream, clientFactory, allocator); + AutoCloseables.close(stream, client, clientFactory, allocator); + allocator.close(); } catch (Exception e) { throw new IOException(e); } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java index 667222f..12ad028 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java @@ -15,39 +15,27 @@ */ package org.apache.arrow.flight.spark; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; import org.apache.spark.sql.vectorized.ColumnarBatch; public class FlightDataReaderFactory implements InputPartition { - private byte[] ticket; - private final String defaultHost; - private final int defaultPort; - private final String username; - private final String password; - private boolean parallel; + private final Broadcast options; - public FlightDataReaderFactory( - byte[] ticket, - String defaultHost, - int defaultPort, String username, String password, boolean parallel) { - this.ticket = ticket; - this.defaultHost = defaultHost; - this.defaultPort = defaultPort; - this.username = username; - this.password = password; - this.parallel = parallel; + public FlightDataReaderFactory(Broadcast options) { + this.options = options; } @Override - public String[] preferredLocations() { - return new String[]{defaultHost}; + public String[] preferredLocations() { + return new String[]{options.value().getHost()}; } @Override public InputPartitionReader createPartitionReader() { - return new FlightDataReader(ticket, defaultHost, defaultPort, username, password, parallel); + return new FlightDataReader(options); } } diff --git 
a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java index afe549a..6925b36 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java @@ -15,8 +15,11 @@ */ package org.apache.arrow.flight.spark; +import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.StringJoiner; import java.util.stream.Collectors; import org.apache.arrow.flight.FlightClient; @@ -27,6 +30,9 @@ import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; @@ -34,7 +40,6 @@ import org.apache.spark.sql.sources.IsNotNull; import org.apache.spark.sql.sources.LessThan; import org.apache.spark.sql.sources.LessThanOrEqual; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns; @@ -57,26 +62,27 @@ public class FlightDataSourceReader implements SupportsScanColumnarBatch, Suppor private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); private static final Joiner WHERE_JOINER = Joiner.on(" and "); private static final Joiner PROJ_JOINER = Joiner.on(", "); + private final Location defaultLocation; private SchemaResult info; private FlightDescriptor descriptor; private StructType schema; - private final Location defaultLocation; private final FlightClientFactory clientFactory; private String sql; + private final Broadcast dataSourceOptions; private Filter[] pushed; + private SparkSession lazySpark; + private JavaSparkContext lazySparkContext; - public FlightDataSourceReader(DataSourceOptions dataSourceOptions) { - defaultLocation = Location.forGrpcInsecure( - dataSourceOptions.get("host").orElse("localhost"), - dataSourceOptions.getInt("port", 47470) - ); + public FlightDataSourceReader(Broadcast dataSourceOptions) { clientFactory = new FlightClientFactory( - defaultLocation, - dataSourceOptions.get("username").orElse("anonymous"), - dataSourceOptions.get("password").orElse(null), - dataSourceOptions.getBoolean("parallel", false) + dataSourceOptions.value().getLocation(), + dataSourceOptions.value().getUsername(), + dataSourceOptions.value().getPassword(), + dataSourceOptions.value().isParallel() ); - sql = dataSourceOptions.get("path").orElse(""); + defaultLocation = dataSourceOptions.value().getLocation(); + sql = dataSourceOptions.value().getSql(); + this.dataSourceOptions = dataSourceOptions; descriptor = getDescriptor(sql); try (FlightClient client = clientFactory.apply()) { info = client.getSchema(descriptor); @@ -167,7 +173,6 @@ private DataType sparkFromArrow(FieldType fieldType) { @Override public List> planBatchInputPartitions() { - System.out.println("planBatchInputPartitions"); return planBatchInputPartitionsParallel(); } @@ -187,17 +192,103 @@ private List> planBatchInputPartitionsSerial(Fligh Location location = 
(endpoint.getLocations().isEmpty()) ? Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) : endpoint.getLocations().get(0); - return new FlightDataReaderFactory(endpoint.getTicket().getBytes(), - location.getUri().getHost(), - location.getUri().getPort(), - clientFactory.getUsername(), - clientFactory.getPassword(), - clientFactory.isParallel()); + FactoryOptions options = dataSourceOptions.value().copy(location, endpoint.getTicket().getBytes()); + LOGGER.warn("X1 {}", dataSourceOptions.value()); + return new FlightDataReaderFactory(lazySparkContext().broadcast(options)); }).collect(Collectors.toList()); LOGGER.info("Created {} batches from arrow endpoints", batches.size()); return batches; } + private SparkSession lazySparkSession() { + if (lazySpark == null) { + this.lazySpark = SparkSession.builder().getOrCreate(); + } + return lazySpark; + } + + private JavaSparkContext lazySparkContext() { + if (lazySparkContext == null) { + this.lazySparkContext = new JavaSparkContext(lazySparkSession().sparkContext()); + } + return lazySparkContext; + } + + static class FactoryOptions implements Serializable { + private final String host; + private final int port; + private final String sql; + private final String username; + private final String password; + private final boolean parallel; + private final byte[] ticket; + + FactoryOptions(Location location, String sql, String username, String password, boolean parallel, byte[] ticket) { + this.host = location.getUri().getHost(); + this.port = location.getUri().getPort(); + this.sql = sql; + this.username = username; + this.password = password; + this.parallel = parallel; + this.ticket = ticket; + } + + public String getUsername() { + return username; + } + + public String getPassword() { + return password; + } + + public boolean isParallel() { + return parallel; + } + + public Location getLocation() { + return Location.forGrpcInsecure(host, port); + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getSql() { + return sql; + } + + @Override + public String toString() { + return new StringJoiner(", ", FactoryOptions.class.getSimpleName() + "[", "]") + .add("host='" + host + "'") + .add("port=" + port) + .add("sql='" + sql + "'") + .add("username='" + username + "'") + .add("password='" + password + "'") + .add("parallel=" + parallel) + .add("ticket=" + Arrays.toString(ticket)) + .toString(); + } + + public byte[] getTicket() { + return ticket; + } + + FactoryOptions copy(Location location, byte[] ticket) { + return new FactoryOptions( + location, + sql, + username, + password, + parallel, + ticket); + } + } + @Override public Filter[] pushFilters(Filter[] filters) { List notPushed = Lists.newArrayList(); @@ -249,6 +340,7 @@ private String generateWhereClause(List pushed) { } else if (filter instanceof LessThanOrEqual) { filterStr.add(String.format("\"%s\" <= %s", ((LessThanOrEqual) filter).attribute(), valueToString(((LessThanOrEqual) filter).value()))); } + //todo fill out rest of Filter types } return WHERE_JOINER.join(filterStr); } @@ -295,8 +387,8 @@ public void pruneColumns(StructType requiredSchema) { StructType schema = readSchema(); List fields = Lists.newArrayList(); List fieldsLeft = Lists.newArrayList(); - Map fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f -> f)); - for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) { + Map fieldNames = 
JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f -> f)); + for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) { String name = field.name(); StructField f = fieldNames.remove(name); if (f != null) { @@ -316,7 +408,7 @@ public void pruneColumns(StructType requiredSchema) { } @Override - public void close() throws Exception { + public void close() { clientFactory.close(); } } diff --git a/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala b/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala similarity index 97% rename from src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala rename to src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala index 9bb953e..d8c210e 100644 --- a/src/main/scala/org/apache/spark/sql/execution/arrow/ModernArrowUtils.scala +++ b/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.execution.arrow import org.apache.arrow.memory.RootAllocator @@ -22,7 +23,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import scala.collection.JavaConverters._ -object ModernArrowUtils { +/** + * FlightArrowUtils is a copy of ArrowUtils with extra support for DateMilli and TimestampMilli + */ +object FlightArrowUtils { val rootAllocator = new RootAllocator(Long.MaxValue) From a7da568121f99b2c0d66aac8f206b48a833c4e00 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 13 Oct 2020 13:23:24 +0000 Subject: [PATCH 17/38] Bump junit from 4.11 to 4.13.1 Bumps [junit](https://github.com/junit-team/junit4) from 4.11 to 4.13.1. - [Release notes](https://github.com/junit-team/junit4/releases) - [Changelog](https://github.com/junit-team/junit4/blob/main/doc/ReleaseNotes4.11.md) - [Commits](https://github.com/junit-team/junit4/compare/r4.11...r4.13.1) Signed-off-by: dependabot[bot] --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 7628398..ee90028 100644 --- a/pom.xml +++ b/pom.xml @@ -522,7 +522,7 @@ limitations under the License. 
junit junit - 4.11 + 4.13.1 test From 2452c50a879d31e7bdd398af9432cb0d533c8340 Mon Sep 17 00:00:00 2001 From: Doron Chen Date: Thu, 14 Jan 2021 19:38:18 +0200 Subject: [PATCH 18/38] add support for TimeStampMicroVector in FlightArrowColumnVector.java Signed-off-by: Doron Chen --- .../flight/spark/FlightArrowColumnVector.java | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java index 60bfa07..eccb3a7 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java @@ -41,6 +41,7 @@ import org.apache.arrow.vector.Float8Vector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroVector; import org.apache.arrow.vector.TimeStampMicroTZVector; import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.TimeStampVector; @@ -200,8 +201,10 @@ public FlightArrowColumnVector(ValueVector vector) { accessor = new DateAccessor((DateDayVector) vector); } else if (vector instanceof DateMilliVector) { accessor = new DateMilliAccessor((DateMilliVector) vector); + } else if (vector instanceof TimeStampMicroVector) { + accessor = new TimestampMicroAccessor((TimeStampMicroVector) vector); } else if (vector instanceof TimeStampMicroTZVector) { - accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); + accessor = new TimestampMicroTZAccessor((TimeStampMicroTZVector) vector); } else if (vector instanceof TimeStampMilliVector) { accessor = new TimestampMilliAccessor((TimeStampMilliVector) vector); } else if (vector instanceof ListVector) { @@ -480,11 +483,26 @@ final int getInt(int rowId) { } } - private static class TimestampAccessor extends ArrowVectorAccessor { + private static class TimestampMicroAccessor extends ArrowVectorAccessor { private final TimeStampVector accessor; - TimestampAccessor(TimeStampMicroTZVector vector) { + TimestampMicroAccessor(TimeStampMicroVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampMicroTZAccessor extends ArrowVectorAccessor { + + private final TimeStampVector accessor; + + TimestampMicroTZAccessor(TimeStampMicroTZVector vector) { super(vector); this.accessor = vector; } From 0c23d9750d7b8af380be5492ab7afc98cfb417ba Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Wed, 13 Apr 2022 16:30:42 -0400 Subject: [PATCH 19/38] Major refactor for spark 3; Compiling but test still failing. 
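The heart of the change is moving off the Spark 2.4 DataSourceV2/ReadSupport/DataSourceReader entry points onto the Spark 3 connector API, where the plumbing follows the TableProvider -> Table -> ScanBuilder -> Scan -> PartitionReaderFactory chain. The class below is an editorial sketch of that entry point, not code from this patch; only the interface names match what the new FlightDataSource/FlightTable files implement, and the bodies are placeholders:

import java.util.Map;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

// Editorial sketch of a Spark 3 TableProvider; not code from this patch.
public class SketchFlightProvider implements TableProvider {
    @Override
    public StructType inferSchema(CaseInsensitiveStringMap options) {
        // Spark asks for the schema first; a Flight source answers this with a
        // GetSchema round trip to the server. CaseInsensitiveStringMap
        // implements Map<String, String>, so it can be passed straight through.
        return getTable(null, new Transform[0], options).schema();
    }

    @Override
    public Table getTable(StructType schema, Transform[] partitioning, Map<String, String> properties) {
        // The real implementation builds a FlightTable from the host/port/path
        // options; its newScanBuilder() hands Spark the rest of the scan chain.
        throw new UnsupportedOperationException("sketch only");
    }
}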
--- .mvn/wrapper/MavenWrapperDownloader.java | 8 +- .mvn/wrapper/maven-wrapper.properties | 19 +- mvnw | 37 +- mvnw.cmd | 37 +- pom.xml | 25 +- .../arrow/flight/spark/DefaultSource.java | 61 --- .../flight/spark/FlightArrowColumnVector.java | 2 +- .../flight/spark/FlightClientFactory.java | 26 +- .../spark/FlightColumnarPartitionReader.java | 48 ++ .../arrow/flight/spark/FlightDataReader.java | 105 ----- .../flight/spark/FlightDataReaderFactory.java | 41 -- .../arrow/flight/spark/FlightDataSource.java | 59 +++ .../flight/spark/FlightDataSourceReader.java | 414 ------------------ .../arrow/flight/spark/FlightPartition.java | 21 + .../flight/spark/FlightPartitionReader.java | 102 +++++ .../spark/FlightPartitionReaderFactory.java | 46 ++ .../apache/arrow/flight/spark/FlightScan.java | 37 ++ .../arrow/flight/spark/FlightScanBuilder.java | 275 ++++++++++++ .../arrow/flight/spark/FlightTable.java | 53 +++ .../execution/arrow/FlightArrowUtils.scala | 7 +- 20 files changed, 719 insertions(+), 704 deletions(-) delete mode 100644 src/main/java/org/apache/arrow/flight/spark/DefaultSource.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java delete mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java delete mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java delete mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightPartition.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightScan.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightTable.java diff --git a/.mvn/wrapper/MavenWrapperDownloader.java b/.mvn/wrapper/MavenWrapperDownloader.java index b4e9919..e76d1f3 100644 --- a/.mvn/wrapper/MavenWrapperDownloader.java +++ b/.mvn/wrapper/MavenWrapperDownloader.java @@ -1,11 +1,11 @@ /* - * Copyright (C) 2019 Ryan Murray + * Copyright 2007-present the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -20,12 +20,12 @@ public class MavenWrapperDownloader { - private static final String WRAPPER_VERSION = "0.5.4"; + private static final String WRAPPER_VERSION = "0.5.6"; /** * Default URL to download the maven-wrapper.jar from, if no 'downloadUrl' is provided. 
*/ private static final String DEFAULT_DOWNLOAD_URL = "https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/" - + WRAPPER_VERSION + "/maven-wrapper-" + WRAPPER_VERSION + " .jar"; + + WRAPPER_VERSION + "/maven-wrapper-" + WRAPPER_VERSION + ".jar"; /** * Path to the maven-wrapper.properties file, which might contain a downloadUrl property to diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties index dfbfc2e..2743cab 100644 --- a/.mvn/wrapper/maven-wrapper.properties +++ b/.mvn/wrapper/maven-wrapper.properties @@ -1,18 +1,3 @@ -# -# Copyright (C) 2019 Ryan Murray -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# +distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.3/apache-maven-3.6.3-bin.zip +wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar -distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.0/apache-maven-3.6.0-bin.zip -wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar diff --git a/mvnw b/mvnw index 4bd1977..a16b543 100755 --- a/mvnw +++ b/mvnw @@ -1,22 +1,25 @@ #!/bin/sh +# ---------------------------------------------------------------------------- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at # -# Copyright (C) 2019 Ryan Murray -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# https://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ---------------------------------------------------------------------------- # ---------------------------------------------------------------------------- -# Maven2 Start Up Batch script +# Maven Start Up Batch script # # Required ENV vars: # ------------------ @@ -209,9 +212,9 @@ else echo "Couldn't find .mvn/wrapper/maven-wrapper.jar, downloading it ..." 
fi if [ -n "$MVNW_REPOURL" ]; then - jarUrl="$MVNW_REPOURL/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + jarUrl="$MVNW_REPOURL/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" else - jarUrl="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + jarUrl="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" fi while IFS="=" read key value; do case "$key" in (wrapperUrl) jarUrl="$value"; break ;; @@ -243,7 +246,7 @@ else else curl --user $MVNW_USERNAME:$MVNW_PASSWORD -o "$wrapperJarPath" "$jarUrl" -f fi - + else if [ "$MVNW_VERBOSE" = true ]; then echo "Falling back to using Java to download" diff --git a/mvnw.cmd b/mvnw.cmd index 574d165..c8d4337 100644 --- a/mvnw.cmd +++ b/mvnw.cmd @@ -1,21 +1,24 @@ +@REM ---------------------------------------------------------------------------- +@REM Licensed to the Apache Software Foundation (ASF) under one +@REM or more contributor license agreements. See the NOTICE file +@REM distributed with this work for additional information +@REM regarding copyright ownership. The ASF licenses this file +@REM to you under the Apache License, Version 2.0 (the +@REM "License"); you may not use this file except in compliance +@REM with the License. You may obtain a copy of the License at @REM -@REM Copyright (C) 2019 Ryan Murray -@REM -@REM Licensed under the Apache License, Version 2.0 (the "License"); -@REM you may not use this file except in compliance with the License. -@REM You may obtain a copy of the License at -@REM -@REM http://www.apache.org/licenses/LICENSE-2.0 -@REM -@REM Unless required by applicable law or agreed to in writing, software -@REM distributed under the License is distributed on an "AS IS" BASIS, -@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@REM See the License for the specific language governing permissions and -@REM limitations under the License. +@REM https://www.apache.org/licenses/LICENSE-2.0 @REM +@REM Unless required by applicable law or agreed to in writing, +@REM software distributed under the License is distributed on an +@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +@REM KIND, either express or implied. See the License for the +@REM specific language governing permissions and limitations +@REM under the License. +@REM ---------------------------------------------------------------------------- @REM ---------------------------------------------------------------------------- -@REM Maven2 Start Up Batch script +@REM Maven Start Up Batch script @REM @REM Required ENV vars: @REM JAVA_HOME - location of a JDK home dir @@ -23,7 +26,7 @@ @REM Optional ENV vars @REM M2_HOME - location of maven2's installed home dir @REM MAVEN_BATCH_ECHO - set to 'on' to enable the echoing of the batch commands -@REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a key stroke before ending +@REM MAVEN_BATCH_PAUSE - set to 'on' to wait for a keystroke before ending @REM MAVEN_OPTS - parameters passed to the Java VM when running Maven @REM e.g. 
to debug Maven itself, use @REM set MAVEN_OPTS=-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=y,address=8000 @@ -117,7 +120,7 @@ SET MAVEN_JAVA_EXE="%JAVA_HOME%\bin\java.exe" set WRAPPER_JAR="%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.jar" set WRAPPER_LAUNCHER=org.apache.maven.wrapper.MavenWrapperMain -set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" +set DOWNLOAD_URL="https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" FOR /F "tokens=1,2 delims==" %%A IN ("%MAVEN_PROJECTBASEDIR%\.mvn\wrapper\maven-wrapper.properties") DO ( IF "%%A"=="wrapperUrl" SET DOWNLOAD_URL=%%B @@ -131,7 +134,7 @@ if exist %WRAPPER_JAR% ( ) ) else ( if not "%MVNW_REPOURL%" == "" ( - SET DOWNLOAD_URL="%MVNW_REPOURL%/io/takari/maven-wrapper/0.5.4/maven-wrapper-0.5.4.jar" + SET DOWNLOAD_URL="%MVNW_REPOURL%/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar" ) if "%MVNW_VERBOSE%" == "true" ( echo Couldn't find %WRAPPER_JAR%, downloading it ... diff --git a/pom.xml b/pom.xml index ee90028..ddb285c 100644 --- a/pom.xml +++ b/pom.xml @@ -26,9 +26,12 @@ 1.0-SNAPSHOT - 0.17.0 - 2.4.5 + + 2.12.14 + 7.0.0 + 3.2.1 1.7.25 + 2.4.4 @@ -306,6 +309,7 @@ limitations under the License. net.alchim31.maven scala-maven-plugin + 4.6.1 false @@ -408,7 +412,7 @@ limitations under the License. org.apache.spark - spark-core_2.11 + spark-core_2.12 ${spark.version} @@ -435,7 +439,7 @@ limitations under the License. org.apache.spark - spark-sql_2.11 + spark-sql_2.12 ${spark.version} @@ -464,6 +468,17 @@ limitations under the License. + + + com.fasterxml.jackson.core + jackson-core + ${jackson-core.version} + + + com.fasterxml.jackson.core + jackson-databind + ${jackson-core.version} + org.apache.arrow flight-core @@ -477,7 +492,7 @@ limitations under the License. org.scala-lang scala-library - 2.11.6 + ${scala.version} diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java deleted file mode 100644 index 38d89d7..0000000 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (C) 2019 Ryan Murray - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.arrow.flight.spark; - -import org.apache.arrow.flight.Location; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; - -public class DefaultSource implements DataSourceV2, ReadSupport { - - private SparkSession lazySpark; - private JavaSparkContext lazySparkContext; - - public DataSourceReader createReader(DataSourceOptions dataSourceOptions) { - Location defaultLocation = Location.forGrpcInsecure( - dataSourceOptions.get("host").orElse("localhost"), - dataSourceOptions.getInt("port", 47470) - ); - String sql = dataSourceOptions.get("path").orElse(""); - FlightDataSourceReader.FactoryOptions options = new FlightDataSourceReader.FactoryOptions( - defaultLocation, - sql, - dataSourceOptions.get("username").orElse("anonymous"), - dataSourceOptions.get("password").orElse(null), - dataSourceOptions.getBoolean("parallel", false), null); - Broadcast bOptions = lazySparkContext().broadcast(options); - return new FlightDataSourceReader(bOptions); - } - - private SparkSession lazySparkSession() { - if (lazySpark == null) { - this.lazySpark = SparkSession.builder().getOrCreate(); - } - return lazySpark; - } - - private JavaSparkContext lazySparkContext() { - if (lazySparkContext == null) { - this.lazySparkContext = new JavaSparkContext(lazySparkSession().sparkContext()); - } - return lazySparkContext; - } -} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java index eccb3a7..54dff3e 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java @@ -52,6 +52,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.arrow.memory.ArrowBuf; import org.apache.spark.sql.execution.arrow.FlightArrowUtils; import org.apache.spark.sql.types.Decimal; import org.apache.spark.sql.vectorized.ColumnVector; @@ -59,7 +60,6 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; -import io.netty.buffer.ArrowBuf; /** * A column vector backed by Apache Arrow. 
Currently calendar interval type and map type are not diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index c060743..e577ed5 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -15,38 +15,30 @@ */ package org.apache.arrow.flight.spark; -import java.util.Iterator; +import java.io.InputStream; +import java.util.Optional; -import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Result; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; public class FlightClientFactory implements AutoCloseable { private final BufferAllocator allocator = new RootAllocator(); private final Location defaultLocation; - private final String username; - private final String password; - private final boolean parallel; + private final Optional trustedCertificates; - public FlightClientFactory(Location defaultLocation, String username, String password, boolean parallel) { + public FlightClientFactory(Location defaultLocation, Optional trustedCertificates) { this.defaultLocation = defaultLocation; - this.username = username; - this.password = (password == null || password.equals("$NULL$")) ? null : password; - this.parallel = parallel; + this.trustedCertificates = trustedCertificates; } public FlightClient apply() { - FlightClient client = FlightClient.builder(allocator, defaultLocation).build(); - client.authenticateBasic(username, password); - if (parallel) { - Iterator res = client.doAction(new Action("PARALLEL")); - res.forEachRemaining(Object::toString); + FlightClient.Builder builder = FlightClient.builder(allocator, defaultLocation); + if (trustedCertificates.isPresent()) { + builder.trustedCertificates(trustedCertificates.get()); } - return client; - + return builder.build(); } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java new file mode 100644 index 0000000..0a90a9e --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java @@ -0,0 +1,48 @@ +package org.apache.arrow.flight.spark; + +import java.io.IOException; + +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.arrow.flight.FlightStream; +import org.apache.spark.sql.vectorized.ColumnVector; + +public class FlightColumnarPartitionReader implements PartitionReader { + private final FlightStream stream; + + public FlightColumnarPartitionReader(FlightStream stream) { + this.stream = stream; + } + + @Override + public void close() throws IOException { + try { + stream.close(); + } catch (Exception e) { + throw new IOException(e); + } + } + + // This is written this way because the Spark interface iterates in a different way. + // E.g., .next() -> .get() vs. 
.hasNext() -> .next() + @Override + public boolean next() throws IOException { + try { + return stream.next(); + } catch (RuntimeException e) { + throw new IOException(e); + } + } + + @Override + public ColumnarBatch get() { + ColumnarBatch batch = new ColumnarBatch( + stream.getRoot().getFieldVectors() + .stream() + .map(FlightArrowColumnVector::new) + .toArray(ColumnVector[]::new) + ); + batch.setNumRows(stream.getRoot().getRowCount()); + return batch; + } +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java deleted file mode 100644 index ed0bea4..0000000 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataReader.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (C) 2019 Ryan Murray - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.spark; - -import java.io.IOException; - -import org.apache.arrow.flight.Action; -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.util.AutoCloseables; -import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class FlightDataReader implements InputPartitionReader { - private static final Logger logger = LoggerFactory.getLogger(FlightDataReader.class); - private FlightClient client; - private FlightStream stream; - private BufferAllocator allocator = null; - private FlightClientFactory clientFactory; - private final Ticket ticket; - private final Broadcast options; - private final Location location; - private boolean parallel; - - public FlightDataReader(Broadcast options) { - this.options = options; - this.location = Location.forGrpcInsecure(options.value().getHost(), options.value().getPort()); - this.ticket = new Ticket(options.value().getTicket()); - } - - private void start() { - if (allocator != null) { - return; - } - FlightDataSourceReader.FactoryOptions options = this.options.getValue(); - this.parallel = options.isParallel(); - this.allocator = new RootAllocator(); - logger.warn("setting up a data reader at host {} and port {} with ticket {}", options.getHost(), options.getPort(), new String(ticket.getBytes())); - clientFactory = new FlightClientFactory(location, options.getUsername(), options.getPassword(), parallel); - client = clientFactory.apply(); - stream = client.getStream(ticket); - if (parallel) { - logger.debug("doing create action for ticket {}", new String(ticket.getBytes())); - client.doAction(new Action("create", ticket.getBytes())).forEachRemaining(Object::toString); - logger.debug("completed create action for 
ticket {}", new String(ticket.getBytes())); - } - } - - @Override - public boolean next() throws IOException { - start(); - try { - return stream.next(); - } catch (Throwable t) { - throw new IOException(t); - } - } - - @Override - public ColumnarBatch get() { - start(); - ColumnarBatch batch = new ColumnarBatch( - stream.getRoot().getFieldVectors() - .stream() - .map(FlightArrowColumnVector::new) - .toArray(ColumnVector[]::new) - ); - batch.setNumRows(stream.getRoot().getRowCount()); - return batch; - } - - @Override - public void close() throws IOException { - try { - if (parallel) { - client.doAction(new Action("delete", ticket.getBytes())).forEachRemaining(Object::toString); - } - AutoCloseables.close(stream, client, clientFactory, allocator); - allocator.close(); - } catch (Exception e) { - throw new IOException(e); - } - } -} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java deleted file mode 100644 index 12ad028..0000000 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataReaderFactory.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (C) 2019 Ryan Murray - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.spark; - -import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.vectorized.ColumnarBatch; - -public class FlightDataReaderFactory implements InputPartition { - - private final Broadcast options; - - public FlightDataReaderFactory(Broadcast options) { - this.options = options; - } - - @Override - public String[] preferredLocations() { - return new String[]{options.value().getHost()}; - } - - @Override - public InputPartitionReader createPartitionReader() { - return new FlightDataReader(options); - } - -} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java new file mode 100644 index 0000000..57751d0 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java @@ -0,0 +1,59 @@ +package org.apache.arrow.flight.spark; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Map; +import java.util.Optional; + +import org.apache.arrow.flight.Location; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class FlightDataSource implements TableProvider, DataSourceRegister { + + private FlightTable makeTable(CaseInsensitiveStringMap options) { + String protocol = options.getOrDefault("protocol", "grpc"); + Location location; + if (protocol 
== "grpc+tls") { + location = Location.forGrpcTls( + options.getOrDefault("host", "localhost"), + Integer.parseInt(options.getOrDefault("port", "47470")) + ); + } else { + location = Location.forGrpcInsecure( + options.getOrDefault("host", "localhost"), + Integer.parseInt(options.getOrDefault("port", "47470")) + ); + } + + String sql = options.getOrDefault("path", ""); + String trustedCertificates = options.getOrDefault("trustedCertificates", ""); + Optional trustedCertificatesIs = trustedCertificates.isBlank() ? Optional.empty() : Optional.of(new ByteArrayInputStream(trustedCertificates.getBytes())); + + return new FlightTable( + String.format("{} Location {} Command {}", shortName(), location.getUri().toString(), sql), + location, + sql, + trustedCertificatesIs + ); + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + return makeTable(options).schema(); + } + + @Override + public String shortName() { + return "flight"; + } + + @Override + public Table getTable(StructType schema, Transform[] partitioning, Map options) { + return makeTable(new CaseInsensitiveStringMap(options)); + } +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java deleted file mode 100644 index 6925b36..0000000 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataSourceReader.java +++ /dev/null @@ -1,414 +0,0 @@ -/* - * Copyright (C) 2019 Ryan Murray - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.arrow.flight.spark; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.StringJoiner; -import java.util.stream.Collectors; - -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.SchemaResult; -import org.apache.arrow.vector.types.FloatingPointPrecision; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.broadcast.Broadcast; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.sources.EqualTo; -import org.apache.spark.sql.sources.Filter; -import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.GreaterThanOrEqual; -import org.apache.spark.sql.sources.IsNotNull; -import org.apache.spark.sql.sources.LessThan; -import org.apache.spark.sql.sources.LessThanOrEqual; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; -import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns; -import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnarBatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; - -import scala.collection.JavaConversions; - -public class FlightDataSourceReader implements SupportsScanColumnarBatch, SupportsPushDownFilters, SupportsPushDownRequiredColumns, AutoCloseable { - private static final Logger LOGGER = LoggerFactory.getLogger(FlightDataSourceReader.class); - private static final Joiner WHERE_JOINER = Joiner.on(" and "); - private static final Joiner PROJ_JOINER = Joiner.on(", "); - private final Location defaultLocation; - private SchemaResult info; - private FlightDescriptor descriptor; - private StructType schema; - private final FlightClientFactory clientFactory; - private String sql; - private final Broadcast dataSourceOptions; - private Filter[] pushed; - private SparkSession lazySpark; - private JavaSparkContext lazySparkContext; - - public FlightDataSourceReader(Broadcast dataSourceOptions) { - clientFactory = new FlightClientFactory( - dataSourceOptions.value().getLocation(), - dataSourceOptions.value().getUsername(), - dataSourceOptions.value().getPassword(), - dataSourceOptions.value().isParallel() - ); - defaultLocation = dataSourceOptions.value().getLocation(); - sql = dataSourceOptions.value().getSql(); - this.dataSourceOptions = dataSourceOptions; - descriptor = getDescriptor(sql); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - - private FlightDescriptor getDescriptor(String path) { - return FlightDescriptor.command(path.getBytes()); - } - - private StructType readSchemaImpl() { - StructField[] fields = info.getSchema().getFields().stream() - .map(field -> - new StructField(field.getName(), - sparkFromArrow(field.getFieldType()), - 
field.isNullable(), - Metadata.empty())) - .toArray(StructField[]::new); - return new StructType(fields); - } - - public StructType readSchema() { - if (schema == null) { - schema = readSchemaImpl(); - } - return schema; - } - - private DataType sparkFromArrow(FieldType fieldType) { - switch (fieldType.getType().getTypeID()) { - case Null: - return DataTypes.NullType; - case Struct: - throw new UnsupportedOperationException("have not implemented Struct type yet"); - case List: - throw new UnsupportedOperationException("have not implemented List type yet"); - case FixedSizeList: - throw new UnsupportedOperationException("have not implemented FixedSizeList type yet"); - case Union: - throw new UnsupportedOperationException("have not implemented Union type yet"); - case Int: - ArrowType.Int intType = (ArrowType.Int) fieldType.getType(); - int bitWidth = intType.getBitWidth(); - if (bitWidth == 8) { - return DataTypes.ByteType; - } else if (bitWidth == 16) { - return DataTypes.ShortType; - } else if (bitWidth == 32) { - return DataTypes.IntegerType; - } else if (bitWidth == 64) { - return DataTypes.LongType; - } - throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth); - case FloatingPoint: - ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType(); - FloatingPointPrecision precision = floatType.getPrecision(); - switch (precision) { - case HALF: - case SINGLE: - return DataTypes.FloatType; - case DOUBLE: - return DataTypes.DoubleType; - } - case Utf8: - return DataTypes.StringType; - case Binary: - case FixedSizeBinary: - return DataTypes.BinaryType; - case Bool: - return DataTypes.BooleanType; - case Decimal: - throw new UnsupportedOperationException("have not implemented Decimal type yet"); - case Date: - return DataTypes.DateType; - case Time: - return DataTypes.TimestampType; //note i don't know what this will do! - case Timestamp: - return DataTypes.TimestampType; - case Interval: - return DataTypes.CalendarIntervalType; - case NONE: - return DataTypes.NullType; - } - throw new IllegalStateException("Unexpected value: " + fieldType); - } - - @Override - public List> planBatchInputPartitions() { - return planBatchInputPartitionsParallel(); - } - - private List> planBatchInputPartitionsParallel() { - - try (FlightClient client = clientFactory.apply()) { - FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); - return planBatchInputPartitionsSerial(info); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - - private List> planBatchInputPartitionsSerial(FlightInfo info) { - LOGGER.warn("planning partitions for endpoints {}", Joiner.on(", ").join(info.getEndpoints().stream().map(e -> e.getLocations().get(0).getUri().toString()).collect(Collectors.toList()))); - List> batches = info.getEndpoints().stream().map(endpoint -> { - Location location = (endpoint.getLocations().isEmpty()) ? 
- Location.forGrpcInsecure(defaultLocation.getUri().getHost(), defaultLocation.getUri().getPort()) : - endpoint.getLocations().get(0); - FactoryOptions options = dataSourceOptions.value().copy(location, endpoint.getTicket().getBytes()); - LOGGER.warn("X1 {}", dataSourceOptions.value()); - return new FlightDataReaderFactory(lazySparkContext().broadcast(options)); - }).collect(Collectors.toList()); - LOGGER.info("Created {} batches from arrow endpoints", batches.size()); - return batches; - } - - private SparkSession lazySparkSession() { - if (lazySpark == null) { - this.lazySpark = SparkSession.builder().getOrCreate(); - } - return lazySpark; - } - - private JavaSparkContext lazySparkContext() { - if (lazySparkContext == null) { - this.lazySparkContext = new JavaSparkContext(lazySparkSession().sparkContext()); - } - return lazySparkContext; - } - - static class FactoryOptions implements Serializable { - private final String host; - private final int port; - private final String sql; - private final String username; - private final String password; - private final boolean parallel; - private final byte[] ticket; - - FactoryOptions(Location location, String sql, String username, String password, boolean parallel, byte[] ticket) { - this.host = location.getUri().getHost(); - this.port = location.getUri().getPort(); - this.sql = sql; - this.username = username; - this.password = password; - this.parallel = parallel; - this.ticket = ticket; - } - - public String getUsername() { - return username; - } - - public String getPassword() { - return password; - } - - public boolean isParallel() { - return parallel; - } - - public Location getLocation() { - return Location.forGrpcInsecure(host, port); - } - - public String getHost() { - return host; - } - - public int getPort() { - return port; - } - - public String getSql() { - return sql; - } - - @Override - public String toString() { - return new StringJoiner(", ", FactoryOptions.class.getSimpleName() + "[", "]") - .add("host='" + host + "'") - .add("port=" + port) - .add("sql='" + sql + "'") - .add("username='" + username + "'") - .add("password='" + password + "'") - .add("parallel=" + parallel) - .add("ticket=" + Arrays.toString(ticket)) - .toString(); - } - - public byte[] getTicket() { - return ticket; - } - - FactoryOptions copy(Location location, byte[] ticket) { - return new FactoryOptions( - location, - sql, - username, - password, - parallel, - ticket); - } - } - - @Override - public Filter[] pushFilters(Filter[] filters) { - List notPushed = Lists.newArrayList(); - List pushed = Lists.newArrayList(); - for (Filter filter : filters) { - boolean isPushed = canBePushed(filter); - if (isPushed) { - pushed.add(filter); - } else { - notPushed.add(filter); - } - } - this.pushed = pushed.toArray(new Filter[0]); - if (!pushed.isEmpty()) { - String whereClause = generateWhereClause(pushed); - mergeWhereDescriptors(whereClause); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - return notPushed.toArray(new Filter[0]); - } - - private void mergeWhereDescriptors(String whereClause) { - sql = String.format("select * from (%s) where %s", sql, whereClause); - descriptor = getDescriptor(sql); - } - - private void mergeProjDescriptors(String projClause) { - sql = String.format("select %s from (%s)", projClause, sql); - descriptor = getDescriptor(sql); - } - - private String generateWhereClause(List pushed) { - List filterStr = 
Lists.newArrayList(); - for (Filter filter : pushed) { - if (filter instanceof IsNotNull) { - filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute())); - } else if (filter instanceof EqualTo) { - filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value()))); - } else if (filter instanceof GreaterThan) { - filterStr.add(String.format("\"%s\" > %s", ((GreaterThan) filter).attribute(), valueToString(((GreaterThan) filter).value()))); - } else if (filter instanceof GreaterThanOrEqual) { - filterStr.add(String.format("\"%s\" <= %s", ((GreaterThanOrEqual) filter).attribute(), valueToString(((GreaterThanOrEqual) filter).value()))); - } else if (filter instanceof LessThan) { - filterStr.add(String.format("\"%s\" < %s", ((LessThan) filter).attribute(), valueToString(((LessThan) filter).value()))); - } else if (filter instanceof LessThanOrEqual) { - filterStr.add(String.format("\"%s\" <= %s", ((LessThanOrEqual) filter).attribute(), valueToString(((LessThanOrEqual) filter).value()))); - } - //todo fill out rest of Filter types - } - return WHERE_JOINER.join(filterStr); - } - - private String valueToString(Object value) { - if (value instanceof String) { - return String.format("'%s'", value); - } - return value.toString(); - } - - private boolean canBePushed(Filter filter) { - if (filter instanceof IsNotNull) { - return true; - } else if (filter instanceof EqualTo) { - return true; - } - if (filter instanceof GreaterThan) { - return true; - } - if (filter instanceof GreaterThanOrEqual) { - return true; - } - if (filter instanceof LessThan) { - return true; - } - if (filter instanceof LessThanOrEqual) { - return true; - } - LOGGER.error("Cant push filter of type " + filter.toString()); - return false; - } - - @Override - public Filter[] pushedFilters() { - return pushed; - } - - @Override - public void pruneColumns(StructType requiredSchema) { - if (requiredSchema.toSeq().isEmpty()) { - return; - } - StructType schema = readSchema(); - List fields = Lists.newArrayList(); - List fieldsLeft = Lists.newArrayList(); - Map fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream().collect(Collectors.toMap(StructField::name, f -> f)); - for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) { - String name = field.name(); - StructField f = fieldNames.remove(name); - if (f != null) { - fields.add(String.format("\"%s\"", name)); - fieldsLeft.add(f); - } - } - if (!fieldNames.isEmpty()) { - this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); - mergeProjDescriptors(PROJ_JOINER.join(fields)); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - - @Override - public void close() { - clientFactory.close(); - } -} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java new file mode 100644 index 0000000..8df21d7 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java @@ -0,0 +1,21 @@ +package org.apache.arrow.flight.spark; + +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.spark.sql.connector.read.InputPartition; + +public class FlightPartition implements InputPartition { + private final FlightEndpoint endpoint; + + public FlightPartition(FlightEndpoint endpoint) { + this.endpoint = endpoint; + } + + @Override + public String[] 
preferredLocations() {
+        return endpoint.getLocations().stream().map(location -> location.getUri().getHost()).toArray(String[]::new);
+    }
+
+    public FlightEndpoint getEndpoint() {
+        return endpoint;
+    }
+}
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
new file mode 100644
index 0000000..e2028bc
--- /dev/null
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
@@ -0,0 +1,102 @@
+package org.apache.arrow.flight.spark;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Optional;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+public class FlightPartitionReader implements PartitionReader<InternalRow> {
+    private final FlightStream stream;
+    private Optional<Iterator<InternalRow>> batch = Optional.empty();
+    private InternalRow row;
+
+    public FlightPartitionReader(FlightStream stream) {
+        this.stream = stream;
+    }
+
+    @Override
+    public void close() throws IOException {
+        try {
+            stream.close();
+        } catch (Exception e) {
+            throw new IOException(e);
+        }
+    }
+
+    private Iterator<InternalRow> getNextBatch() {
+        ColumnarBatch batch = new ColumnarBatch(
+            stream.getRoot().getFieldVectors()
+                .stream()
+                .map(FlightArrowColumnVector::new)
+                .toArray(ColumnVector[]::new)
+        );
+        batch.setNumRows(stream.getRoot().getRowCount());
+        return batch.rowIterator();
+    }
+
+    // This is written this way because the Spark interface iterates in a different way.
+    // E.g., .next() -> .get() vs. .hasNext() -> .next()
+    @Override
+    public boolean next() throws IOException {
+        try {
+            // Try the iterator first, then get the next batch.
+            // Not quite Rust match expressions...
+            return batch.map(currentBatch -> {
+                // Are there still rows in this batch?
+                if (currentBatch.hasNext()) {
+                    row = currentBatch.next();
+                    return true;
+                // No more rows, get the next batch
+                } else {
+                    // Is there another batch?
+                    if (stream.next()) {
+                        // Yes, then fetch it.
+                        Iterator<InternalRow> nextBatch = getNextBatch();
+                        batch = Optional.of(nextBatch);
+                        if (nextBatch.hasNext()) {
+                            row = nextBatch.next();
+                            return true;
+                        // Odd, we got an empty batch
+                        } else {
+                            return false;
+                        }
+                    // This partition / stream is complete
+                    } else {
+                        return false;
+                    }
+                }
+            // Fetch the first batch
+            }).orElseGet(() -> {
+                // Is the stream empty?
+                if (stream.next()) {
+                    // No, then fetch the first batch
+                    Iterator<InternalRow> firstBatch = getNextBatch();
+                    batch = Optional.of(firstBatch);
+                    if (firstBatch.hasNext()) {
+                        row = firstBatch.next();
+                        return true;
+                    // Odd, we got an empty batch
+                    } else {
+                        return false;
+                    }
+                // The stream was empty...
+ } else { + return false; + } + }); + } catch (RuntimeException e) { + throw new IOException(e); + } + } + + @Override + public InternalRow get() { + return row; + } + +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java new file mode 100644 index 0000000..d442464 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java @@ -0,0 +1,46 @@ +package org.apache.arrow.flight.spark; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FlightPartitionReaderFactory implements PartitionReaderFactory { + private static final Logger logger = LoggerFactory.getLogger(FlightPartitionReaderFactory.class); + private final FlightClientFactory clientFactory; + + public FlightPartitionReaderFactory(FlightClientFactory clientFactory) { + this.clientFactory = clientFactory; + } + + private FlightStream createStream(InputPartition iPartition) { + // This feels wrong but this is what upstream spark sources do to. + FlightPartition partition = (FlightPartition) iPartition; + logger.info("Reading Flight data from locations: {}", (Object) partition.preferredLocations()); + FlightClient client = clientFactory.apply(); + return client.getStream(partition.getEndpoint().getTicket()); + } + + @Override + public PartitionReader createReader(InputPartition partition) { + FlightStream stream = createStream(partition); + return new FlightPartitionReader(stream); + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + FlightStream stream = createStream(partition); + return new FlightColumnarPartitionReader(stream); + } + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java new file mode 100644 index 0000000..14ac6c3 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java @@ -0,0 +1,37 @@ +package org.apache.arrow.flight.spark; + +import org.apache.spark.sql.connector.read.Scan; +import org.apache.arrow.flight.FlightInfo; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.types.StructType; + +public class FlightScan implements Scan, Batch { + private final StructType schema; + private final FlightInfo info; + public FlightScan(StructType schema, FlightInfo info) { + this.schema = schema; + this.info = info; + } + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public InputPartition[] planInputPartitions() { + InputPartition[] batches = info.getEndpoints().stream().map(endpoint -> { + return new FlightPartition(endpoint); + }).toArray(InputPartition[]::new); + return batches; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + // TODO Auto-generated method stub + return null; + } + +} diff --git 
a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java
new file mode 100644
index 0000000..7f15e75
--- /dev/null
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java
@@ -0,0 +1,275 @@
+// Portions of this file were taken from:
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.spark;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.SchemaResult;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.connector.read.SupportsPushDownFilters;
+import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
+import org.apache.spark.sql.sources.*;
+import org.apache.spark.sql.types.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import scala.collection.JavaConversions;
+
+import com.google.common.collect.Lists;
+import com.google.common.base.Joiner;
+
+public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters {
+    private static final Logger LOGGER = LoggerFactory.getLogger(FlightScanBuilder.class);
+    private static final Joiner WHERE_JOINER = Joiner.on(" and ");
+    private static final Joiner PROJ_JOINER = Joiner.on(", ");
+    private SchemaResult info;
+    private StructType schema;
+    private final FlightClientFactory clientFactory;
+    private FlightDescriptor descriptor;
+    private String sql;
+    private Filter[] pushed;
+
+    public FlightScanBuilder(FlightClientFactory clientFactory, String sql) {
+        this.clientFactory = clientFactory;
+        this.sql = sql;
+        descriptor = getDescriptor(sql);
+        try (FlightClient client = clientFactory.apply()) {
+            info = client.getSchema(descriptor);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public Scan build() {
+        try (FlightClient client = clientFactory.apply()) {
+            FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes()));
+            return new FlightScan(readSchema(), info);
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private boolean canBePushed(Filter filter) {
+        if (filter instanceof IsNotNull) {
+            return true;
+        } else if (filter instanceof EqualTo) {
+            return true;
+        }
+        if (filter instanceof GreaterThan) {
+            return true;
+        }
+        if (filter instanceof GreaterThanOrEqual) {
+            return true;
+        }
+        if (filter instanceof LessThan) {
+            return true;
+        }
+        if (filter instanceof LessThanOrEqual) {
+            return true;
+        }
+        LOGGER.error("Can't push filter of type " + filter.toString());
+        return false;
+    }
+
+    private String valueToString(Object value) {
+        if (value instanceof String) {
+            return String.format("'%s'", value);
+        }
+        return value.toString();
+    }
+
+    private String generateWhereClause(List<Filter> pushed) {
+        List<String> filterStr = Lists.newArrayList();
+        for (Filter filter : pushed) {
+            if (filter instanceof IsNotNull) {
+                filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute()));
+            } else if (filter instanceof EqualTo) {
+                filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value())));
+            } else if (filter instanceof GreaterThan) {
+                filterStr.add(String.format("\"%s\" > %s", ((GreaterThan) filter).attribute(), valueToString(((GreaterThan) filter).value())));
+            } else if (filter instanceof GreaterThanOrEqual) {
+                filterStr.add(String.format("\"%s\" >= %s", ((GreaterThanOrEqual) filter).attribute(), valueToString(((GreaterThanOrEqual) filter).value())));
+            } else if (filter instanceof LessThan) {
+                filterStr.add(String.format("\"%s\" < %s", ((LessThan) filter).attribute(), valueToString(((LessThan) filter).value())));
+            } else if (filter instanceof LessThanOrEqual) {
+                filterStr.add(String.format("\"%s\" <= %s", ((LessThanOrEqual) filter).attribute(), valueToString(((LessThanOrEqual) filter).value())));
+            }
+            //todo fill out rest of Filter types
+        }
+        return WHERE_JOINER.join(filterStr);
+    }
+
+    private FlightDescriptor getDescriptor(String sql) {
+        return FlightDescriptor.command(sql.getBytes());
+    }
+
+    private void mergeWhereDescriptors(String whereClause) {
+        sql = String.format("select * from (%s) where %s", sql, whereClause);
+        descriptor = getDescriptor(sql);
+    }
+
+    @Override
+    public Filter[] pushFilters(Filter[] filters) {
+        List<Filter> notPushed = Lists.newArrayList();
+        List<Filter> pushed = Lists.newArrayList();
+        for (Filter filter : filters) {
+            boolean isPushed = canBePushed(filter);
+            if (isPushed) {
+                pushed.add(filter);
+            } else {
+                notPushed.add(filter);
+            }
+        }
+        this.pushed = pushed.toArray(new Filter[0]);
+        if (!pushed.isEmpty()) {
+            String whereClause = generateWhereClause(pushed);
+            mergeWhereDescriptors(whereClause);
+            try (FlightClient client = clientFactory.apply()) {
+                info = client.getSchema(descriptor);
+            } catch (InterruptedException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        return notPushed.toArray(new Filter[0]);
+    }
+
+    @Override
+    public Filter[] pushedFilters() {
+        return pushed;
+    }
+
+    private DataType sparkFromArrow(FieldType fieldType) {
+        switch (fieldType.getType().getTypeID()) {
+            case Null:
+                return DataTypes.NullType;
+            case Struct:
+                throw new UnsupportedOperationException("have not implemented Struct type yet");
+            case List:
+                throw new UnsupportedOperationException("have not implemented List type yet");
+            case FixedSizeList:
+                throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
+            case Union:
+                throw new UnsupportedOperationException("have not implemented Union type yet");
+            case Int:
+                ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
+                int bitWidth = intType.getBitWidth();
+                if (bitWidth == 8) {
+                    return DataTypes.ByteType;
+                } else if (bitWidth == 16) {
+                    return DataTypes.ShortType;
+                } else if (bitWidth == 32) {
+                    return DataTypes.IntegerType;
+                } else if (bitWidth == 64) {
+                    return DataTypes.LongType;
+                }
+                throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth);
+            case FloatingPoint:
+                ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint)
fieldType.getType(); + FloatingPointPrecision precision = floatType.getPrecision(); + switch (precision) { + case HALF: + case SINGLE: + return DataTypes.FloatType; + case DOUBLE: + return DataTypes.DoubleType; + } + case Utf8: + return DataTypes.StringType; + case Binary: + case FixedSizeBinary: + return DataTypes.BinaryType; + case Bool: + return DataTypes.BooleanType; + case Decimal: + throw new UnsupportedOperationException("have not implemented Decimal type yet"); + case Date: + return DataTypes.DateType; + case Time: + return DataTypes.TimestampType; // note i don't know what this will do! + case Timestamp: + return DataTypes.TimestampType; + case Interval: + return DataTypes.CalendarIntervalType; + case NONE: + return DataTypes.NullType; + default: + throw new IllegalStateException("Unexpected value: " + fieldType); + } + } + + private StructType readSchemaImpl() { + StructField[] fields = info.getSchema().getFields().stream() + .map(field -> new StructField(field.getName(), + sparkFromArrow(field.getFieldType()), + field.isNullable(), + Metadata.empty())) + .toArray(StructField[]::new); + return new StructType(fields); + } + + public StructType readSchema() { + if (schema == null) { + schema = readSchemaImpl(); + } + return schema; + } + + private void mergeProjDescriptors(String projClause) { + sql = String.format("select %s from (%s)", projClause, sql); + descriptor = getDescriptor(sql); + } + + @Override + public void pruneColumns(StructType requiredSchema) { + if (requiredSchema.toSeq().isEmpty()) { + return; + } + StructType schema = readSchema(); + List fields = Lists.newArrayList(); + List fieldsLeft = Lists.newArrayList(); + Map fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream() + .collect(Collectors.toMap(StructField::name, f -> f)); + for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) { + String name = field.name(); + StructField f = fieldNames.remove(name); + if (f != null) { + fields.add(String.format("\"%s\"", name)); + fieldsLeft.add(f); + } + } + if (!fieldNames.isEmpty()) { + this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); + mergeProjDescriptors(PROJ_JOINER.join(fields)); + try (FlightClient client = clientFactory.apply()) { + info = client.getSchema(descriptor); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java new file mode 100644 index 0000000..0fb3ae7 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java @@ -0,0 +1,53 @@ +package org.apache.arrow.flight.spark; + +import java.io.InputStream; +import java.util.Optional; +import java.util.Set; + +import org.apache.arrow.flight.Location; +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class FlightTable implements Table, SupportsRead { + private static final Set CAPABILITIES = Set.of(TableCapability.BATCH_READ); + private final String name; + private final FlightClientFactory clientFactory; + private final String sql; + private StructType schema; + + public FlightTable(String name, Location location, String sql, Optional trustedCertificates) { + this.name 
= name;
+        clientFactory = new FlightClientFactory(location, trustedCertificates);
+        this.sql = sql;
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+
+    @Override
+    public StructType schema() {
+        if (schema == null) {
+            schema = (new FlightScanBuilder(clientFactory, sql)).readSchema();
+        }
+        return schema;
+    }
+
+    // TODO - We could probably implement partitioning() but it would require server side support
+
+    @Override
+    public Set<TableCapability> capabilities() {
+        // We only support reading for now
+        return CAPABILITIES;
+    }
+
+    @Override
+    public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
+        return new FlightScanBuilder(clientFactory, sql);
+    }
+}
diff --git a/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala b/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala
index d8c210e..a4960e7 100644
--- a/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala
+++ b/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala
@@ -126,11 +126,8 @@ object FlightArrowUtils {

   /** Return Map with conf settings to be used in ArrowPythonRunner */
   def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
-    val timeZoneConf = if (conf.pandasRespectSessionTimeZone) {
-      Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
-    } else {
-      Nil
-    }
+    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key ->
+      conf.sessionLocalTimeZone)
     val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
       conf.pandasGroupedMapAssignColumnsByName.toString)
     Map(timeZoneConf ++ pandasColsByName: _*)

From 8a2aaca071aa64e5b16a939a8afda556cc60cb4c Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Wed, 13 Apr 2022 17:09:24 -0400
Subject: [PATCH 20/38] Getting further in tests. Need to fix unserializable tasks.
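
Background on the failing piece (a sketch for orientation, not text from the original
commit): Spark's DataSourceV2 read path serializes InputPartition and
PartitionReaderFactory instances on the driver and ships them to executor tasks, so
both interfaces extend java.io.Serializable and everything they reference must be
serializable as well. Holding a live FlightClient, a BufferAllocator, or a raw
FlightEndpoint inside them is what makes the task unserializable. A minimal
illustration of the constraint, using a hypothetical ExamplePartition:

    import org.apache.spark.sql.connector.read.InputPartition;

    // InputPartition extends java.io.Serializable, so Spark serializes every
    // field of this class when it schedules the scan task on an executor.
    public class ExamplePartition implements InputPartition {
        // Fine: Strings and primitives serialize cleanly.
        private final String host;
        // A field like "private final FlightClient client;" would instead fail
        // task submission with java.io.NotSerializableException.

        public ExamplePartition(String host) {
            this.host = host;
        }

        @Override
        public String[] preferredLocations() {
            return new String[] { host };
        }
    }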
--- pom.xml | 2 +- .../{FlightDataSource.java => DefaultSource.java} | 2 +- .../org/apache/arrow/flight/spark/FlightScan.java | 12 +++++++++--- .../apache/arrow/flight/spark/FlightScanBuilder.java | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) rename src/main/java/org/apache/arrow/flight/spark/{FlightDataSource.java => DefaultSource.java} (96%) diff --git a/pom.xml b/pom.xml index ddb285c..27e6e99 100644 --- a/pom.xml +++ b/pom.xml @@ -31,7 +31,7 @@ 7.0.0 3.2.1 1.7.25 - 2.4.4 + 2.12.3 diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java similarity index 96% rename from src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java rename to src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index 57751d0..c499be2 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightDataSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -13,7 +13,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class FlightDataSource implements TableProvider, DataSourceRegister { +public class DefaultSource implements TableProvider, DataSourceRegister { private FlightTable makeTable(CaseInsensitiveStringMap options) { String protocol = options.getOrDefault("protocol", "grpc"); diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java index 14ac6c3..037ee62 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java @@ -9,9 +9,11 @@ public class FlightScan implements Scan, Batch { private final StructType schema; + private final FlightClientFactory clientFactory; private final FlightInfo info; - public FlightScan(StructType schema, FlightInfo info) { + public FlightScan(StructType schema, FlightClientFactory clientFactory, FlightInfo info) { this.schema = schema; + this.clientFactory = clientFactory; this.info = info; } @@ -20,6 +22,11 @@ public StructType readSchema() { return schema; } + @Override + public Batch toBatch() { + return this; + } + @Override public InputPartition[] planInputPartitions() { InputPartition[] batches = info.getEndpoints().stream().map(endpoint -> { @@ -30,8 +37,7 @@ public InputPartition[] planInputPartitions() { @Override public PartitionReaderFactory createReaderFactory() { - // TODO Auto-generated method stub - return null; + return new FlightPartitionReaderFactory(clientFactory); } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index 7f15e75..109f2d3 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -68,7 +68,7 @@ public FlightScanBuilder(FlightClientFactory clientFactory, String sql) { public Scan build() { try (FlightClient client = clientFactory.apply()) { FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); - return new FlightScan(readSchema(), info); + return new FlightScan(readSchema(), this.clientFactory, info); } catch (InterruptedException e) { throw new RuntimeException(e); } From 13287d81b6da271d7778ce8072332984cecf0c9b Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Thu, 14 Apr 2022 21:38:08 -0400 Subject: [PATCH 21/38] Further still need to debug serialization issues. 
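
The direction taken below (sketched here with hypothetical names; FlightClientOptions
in this patch applies the same idea): move only plain, serializable configuration
across the driver-to-executor boundary and rebuild the non-serializable Flight client
on the executor after deserialization.

    import java.io.Serializable;

    import org.apache.arrow.flight.FlightClient;
    import org.apache.arrow.flight.Location;
    import org.apache.arrow.memory.BufferAllocator;

    // Hypothetical sketch: only host and port are serialized; the FlightClient
    // itself is created lazily on the executor and is never shipped over the wire.
    public class ExampleConnectionOptions implements Serializable {
        private final String host;
        private final int port;

        public ExampleConnectionOptions(String host, int port) {
            this.host = host;
            this.port = port;
        }

        // Called executor-side after deserialization; the caller owns and closes
        // the returned client.
        public FlightClient connect(BufferAllocator allocator) {
            return FlightClient.builder(allocator, Location.forGrpcInsecure(host, port)).build();
        }
    }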
--- .../arrow/flight/spark/DefaultSource.java | 8 +--
 .../flight/spark/FlightClientFactory.java | 14 +++--
 .../flight/spark/FlightClientOptions.java | 13 +++++
 .../spark/FlightPartitionReaderFactory.java | 21 +++++--
 .../apache/arrow/flight/spark/FlightScan.java | 10 ++--
 .../arrow/flight/spark/FlightScanBuilder.java | 57 +++++++++++--------
 .../arrow/flight/spark/FlightTable.java | 14 ++---
 7 files changed, 87 insertions(+), 50 deletions(-)
 create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java

diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
index c499be2..b175d98 100644
--- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
+++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
@@ -1,9 +1,6 @@
 package org.apache.arrow.flight.spark;

-import java.io.ByteArrayInputStream;
-import java.io.InputStream;
 import java.util.Map;
-import java.util.Optional;

 import org.apache.arrow.flight.Location;
 import org.apache.spark.sql.connector.catalog.TableProvider;
@@ -32,13 +29,14 @@ private FlightTable makeTable(CaseInsensitiveStringMap options) {

     String sql = options.getOrDefault("path", "");
     String trustedCertificates = options.getOrDefault("trustedCertificates", "");
-    Optional<InputStream> trustedCertificatesIs = trustedCertificates.isBlank() ? Optional.empty() : Optional.of(new ByteArrayInputStream(trustedCertificates.getBytes()));
+
+    FlightClientOptions clientOptions = trustedCertificates.isEmpty() ? null : new FlightClientOptions(trustedCertificates);

     return new FlightTable(
       String.format("%s Location %s Command %s", shortName(), location.getUri().toString(), sql),
       location,
       sql,
-      trustedCertificatesIs
+      clientOptions
     );
   }
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
index e577ed5..835f050 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
@@ -15,8 +15,8 @@
  */
 package org.apache.arrow.flight.spark;

+import java.io.ByteArrayInputStream;
 import java.io.InputStream;
-import java.util.Optional;

 import org.apache.arrow.flight.FlightClient;
 import org.apache.arrow.flight.Location;
@@ -26,17 +26,19 @@ public class FlightClientFactory implements AutoCloseable {
   private final BufferAllocator allocator = new RootAllocator();
   private final Location defaultLocation;
-  private final Optional<InputStream> trustedCertificates;
+  private InputStream trustedCertificates;

-  public FlightClientFactory(Location defaultLocation, Optional<InputStream> trustedCertificates) {
+  public FlightClientFactory(Location defaultLocation, FlightClientOptions clientOptions) {
     this.defaultLocation = defaultLocation;
-    this.trustedCertificates = trustedCertificates;
+    if (clientOptions != null) {
+      this.trustedCertificates = new ByteArrayInputStream(clientOptions.getTrustedCertificates().getBytes());
+    }
   }

   public FlightClient apply() {
     FlightClient.Builder builder = FlightClient.builder(allocator, defaultLocation);
-    if (trustedCertificates.isPresent()) {
-      builder.trustedCertificates(trustedCertificates.get());
+    if (trustedCertificates != null) {
+      builder.trustedCertificates(trustedCertificates);
     }
     return builder.build();
   }
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java
new file mode 100644
index 
0000000..a9ff945 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java @@ -0,0 +1,13 @@ +package org.apache.arrow.flight.spark; + +public class FlightClientOptions { + private final String trustedCertificates; + + public FlightClientOptions(String trustedCertificates) { + this.trustedCertificates = trustedCertificates; + } + + public String getTrustedCertificates() { + return trustedCertificates; + } +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java index d442464..924735f 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java @@ -12,18 +12,29 @@ public class FlightPartitionReaderFactory implements PartitionReaderFactory { private static final Logger logger = LoggerFactory.getLogger(FlightPartitionReaderFactory.class); - private final FlightClientFactory clientFactory; + private final FlightClientOptions clientOptions; - public FlightPartitionReaderFactory(FlightClientFactory clientFactory) { - this.clientFactory = clientFactory; + public FlightPartitionReaderFactory(FlightClientOptions clientOptions) { + this.clientOptions = clientOptions; } private FlightStream createStream(InputPartition iPartition) { // This feels wrong but this is what upstream spark sources do to. FlightPartition partition = (FlightPartition) iPartition; logger.info("Reading Flight data from locations: {}", (Object) partition.preferredLocations()); - FlightClient client = clientFactory.apply(); - return client.getStream(partition.getEndpoint().getTicket()); + // TODO - Should we handle multiple locations? 
+ try ( + FlightClientFactory clientFactory = new FlightClientFactory( + partition.getEndpoint().getLocations().get(0), + clientOptions + ); + ) { + try (FlightClient client = clientFactory.apply()) { + return client.getStream(partition.getEndpoint().getTicket()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java index 037ee62..7899951 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java @@ -1,6 +1,7 @@ package org.apache.arrow.flight.spark; import org.apache.spark.sql.connector.read.Scan; + import org.apache.arrow.flight.FlightInfo; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.InputPartition; @@ -9,12 +10,13 @@ public class FlightScan implements Scan, Batch { private final StructType schema; - private final FlightClientFactory clientFactory; private final FlightInfo info; - public FlightScan(StructType schema, FlightClientFactory clientFactory, FlightInfo info) { + private final FlightClientOptions clientOptions; + + public FlightScan(StructType schema, FlightInfo info, FlightClientOptions clientOptions) { this.schema = schema; - this.clientFactory = clientFactory; this.info = info; + this.clientOptions = clientOptions; } @Override @@ -37,7 +39,7 @@ public InputPartition[] planInputPartitions() { @Override public PartitionReaderFactory createReaderFactory() { - return new FlightPartitionReaderFactory(clientFactory); + return new FlightPartitionReaderFactory(clientOptions); } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index 109f2d3..256cb16 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -24,6 +24,7 @@ import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -48,29 +49,35 @@ public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredC private static final Joiner PROJ_JOINER = Joiner.on(", "); private SchemaResult info; private StructType schema; - private final FlightClientFactory clientFactory; + private final Location location; + private final FlightClientOptions clientOptions; private FlightDescriptor descriptor; private String sql; private Filter[] pushed; - public FlightScanBuilder(FlightClientFactory clientFactory, String sql) { - this.clientFactory = clientFactory; + public FlightScanBuilder(Location location, FlightClientOptions clientOptions, String sql) { + this.location = location; + this.clientOptions = clientOptions; this.sql = sql; descriptor = getDescriptor(sql); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); + try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { + try (FlightClient client = clientFactory.apply()) { + info = client.getSchema(descriptor); + } catch (InterruptedException e) { + throw new 
RuntimeException(e); + } } } @Override public Scan build() { - try (FlightClient client = clientFactory.apply()) { - FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); - return new FlightScan(readSchema(), this.clientFactory, info); - } catch (InterruptedException e) { - throw new RuntimeException(e); + try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { + try (FlightClient client = clientFactory.apply()) { + FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); + return new FlightScan(readSchema(), info, clientOptions); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } } } @@ -147,13 +154,15 @@ public Filter[] pushFilters(Filter[] filters) { } this.pushed = pushed.toArray(new Filter[0]); if (!pushed.isEmpty()) { - String whereClause = generateWhereClause(pushed); - mergeWhereDescriptors(whereClause); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + String whereClause = generateWhereClause(pushed); + mergeWhereDescriptors(whereClause); + try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { + try (FlightClient client = clientFactory.apply()) { + info = client.getSchema(descriptor); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } } return notPushed.toArray(new Filter[0]); } @@ -265,10 +274,12 @@ public void pruneColumns(StructType requiredSchema) { if (!fieldNames.isEmpty()) { this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); mergeProjDescriptors(PROJ_JOINER.join(fields)); - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); + try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { + try (FlightClient client = clientFactory.apply()) { + info = client.getSchema(descriptor); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } } } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java index 0fb3ae7..6fd0548 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java @@ -1,7 +1,5 @@ package org.apache.arrow.flight.spark; -import java.io.InputStream; -import java.util.Optional; import java.util.Set; import org.apache.arrow.flight.Location; @@ -15,14 +13,16 @@ public class FlightTable implements Table, SupportsRead { private static final Set CAPABILITIES = Set.of(TableCapability.BATCH_READ); private final String name; - private final FlightClientFactory clientFactory; + private final Location location; private final String sql; + private final FlightClientOptions clientOptions; private StructType schema; - public FlightTable(String name, Location location, String sql, Optional trustedCertificates) { + public FlightTable(String name, Location location, String sql, FlightClientOptions clientOptions) { this.name = name; - clientFactory = new FlightClientFactory(location, trustedCertificates); + this.location = location; this.sql = sql; + this.clientOptions = clientOptions; } @Override @@ -33,7 +33,7 @@ public String name() { @Override public StructType schema() { if (schema == null) { - schema = (new FlightScanBuilder(clientFactory, sql)).readSchema(); + 
schema = (new FlightScanBuilder(location, clientOptions, sql)).readSchema(); } return schema; } @@ -48,6 +48,6 @@ public Set capabilities() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return new FlightScanBuilder(clientFactory, sql); + return new FlightScanBuilder(location, clientOptions, sql); } } From 50768cba5e8175e4fc0fc446dbfb63596ec65b75 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Fri, 15 Apr 2022 21:42:15 -0400 Subject: [PATCH 22/38] Unit tests passing. --- .../flight/spark/FlightClientOptions.java | 4 +- .../spark/FlightColumnarPartitionReader.java | 29 ++++++----- .../flight/spark/FlightEndpointWrapper.java | 41 ++++++++++++++++ .../arrow/flight/spark/FlightPartition.java | 9 ++-- .../flight/spark/FlightPartitionReader.java | 30 +++++++----- .../spark/FlightPartitionReaderFactory.java | 36 +++----------- .../apache/arrow/flight/spark/FlightScan.java | 3 +- .../arrow/flight/spark/FlightScanBuilder.java | 49 +++++++------------ .../arrow/flight/spark/FlightTable.java | 6 ++- .../arrow/flight/spark/TestConnector.java | 43 +++++++++++----- 10 files changed, 146 insertions(+), 104 deletions(-) create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java index a9ff945..11cec47 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java @@ -1,6 +1,8 @@ package org.apache.arrow.flight.spark; -public class FlightClientOptions { +import java.io.Serializable; + +public class FlightClientOptions implements Serializable { private final String trustedCertificates; public FlightClientOptions(String trustedCertificates) { diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java index 0a90a9e..a23aa23 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java @@ -4,23 +4,21 @@ import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.util.AutoCloseables; import org.apache.spark.sql.vectorized.ColumnVector; public class FlightColumnarPartitionReader implements PartitionReader { + private final FlightClientFactory clientFactory;; + private final FlightClient client; private final FlightStream stream; - public FlightColumnarPartitionReader(FlightStream stream) { - this.stream = stream; - } - - @Override - public void close() throws IOException { - try { - stream.close(); - } catch (Exception e) { - throw new IOException(e); - } + public FlightColumnarPartitionReader(FlightClientOptions clientOptions, FlightPartition partition) { + // TODO - Should we handle multiple locations? + clientFactory = new FlightClientFactory(partition.getEndpoint().get().getLocations().get(0), clientOptions); + client = clientFactory.apply(); + stream = client.getStream(partition.getEndpoint().get().getTicket()); } // This is written this way because the Spark interface iterates in a different way. 
@@ -45,4 +43,13 @@ public ColumnarBatch get() {
     batch.setNumRows(stream.getRoot().getRowCount());
     return batch;
   }
+
+  @Override
+  public void close() throws IOException {
+    try {
+      AutoCloseables.close(stream, client, clientFactory);
+    } catch (Exception e) {
+      throw new IOException(e);
+    }
+  }
 }
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java b/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java
new file mode 100644
index 0000000..78df97a
--- /dev/null
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java
@@ -0,0 +1,41 @@
+package org.apache.arrow.flight.spark;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.Ticket;
+
+// This is needed for FlightEndpoint to be Serializable in Spark.
+// org.apache.arrow.flight.FlightEndpoint is a POJO of Serializable types.
+// However, if Spark is using built-in Java serialization instead of Kryo, then we must implement Serializable.
+public class FlightEndpointWrapper implements Serializable {
+    private FlightEndpoint inner;
+
+    public FlightEndpointWrapper(FlightEndpoint inner) {
+        this.inner = inner;
+    }
+
+    public FlightEndpoint get() {
+        return inner;
+    }
+
+    private void writeObject(ObjectOutputStream out) throws IOException {
+        ArrayList<URI> locations = inner.getLocations().stream().map(location -> location.getUri()).collect(Collectors.toCollection(ArrayList::new));
+        out.writeObject(locations);
+        out.write(inner.getTicket().getBytes());
+    }
+
+    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+        @SuppressWarnings("unchecked")
+        Location[] locations = ((ArrayList<URI>) in.readObject()).stream().map(l -> new Location(l)).toArray(Location[]::new);
+        byte[] ticket = in.readAllBytes();
+        this.inner = new FlightEndpoint(new Ticket(ticket), locations);
+    }
+}
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
index 8df21d7..51c8b8c 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
@@ -1,21 +1,20 @@
 package org.apache.arrow.flight.spark;

-import org.apache.arrow.flight.FlightEndpoint;
 import org.apache.spark.sql.connector.read.InputPartition;

 public class FlightPartition implements InputPartition {
-    private final FlightEndpoint endpoint;
+    private final FlightEndpointWrapper endpoint;

-    public FlightPartition(FlightEndpoint endpoint) {
+    public FlightPartition(FlightEndpointWrapper endpoint) {
         this.endpoint = endpoint;
     }

     @Override
     public String[] preferredLocations() {
-        return endpoint.getLocations().stream().map(location -> location.getUri().getHost()).toArray(String[]::new);
+        return endpoint.get().getLocations().stream().map(location -> location.getUri().getHost()).toArray(String[]::new);
     }

-    public FlightEndpoint getEndpoint() {
+    public FlightEndpointWrapper getEndpoint() {
         return endpoint;
     }
 }
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
index e2028bc..70f4535 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
+++ 
b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java @@ -4,28 +4,26 @@ import java.util.Iterator; import java.util.Optional; +import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.util.AutoCloseables; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; public class FlightPartitionReader implements PartitionReader { + private final FlightClientFactory clientFactory;; + private final FlightClient client; private final FlightStream stream; private Optional> batch; private InternalRow row; - public FlightPartitionReader(FlightStream stream) { - this.stream = stream; - } - - @Override - public void close() throws IOException { - try { - stream.close(); - } catch (Exception e) { - throw new IOException(e); - } + public FlightPartitionReader(FlightClientOptions clientOptions, FlightPartition partition) { + // TODO - Should we handle multiple locations? + clientFactory = new FlightClientFactory(partition.getEndpoint().get().getLocations().get(0), clientOptions); + client = clientFactory.apply(); + stream = client.getStream(partition.getEndpoint().get().getTicket()); } private Iterator getNextBatch() { @@ -98,5 +96,13 @@ public boolean next() throws IOException { public InternalRow get() { return row; } - + + @Override + public void close() throws IOException { + try { + AutoCloseables.close(stream, client, clientFactory); + } catch (Exception e) { + throw new IOException(e); + } + } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java index 924735f..3a8f45f 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java @@ -1,52 +1,30 @@ package org.apache.arrow.flight.spark; -import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightStream; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.vectorized.ColumnarBatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class FlightPartitionReaderFactory implements PartitionReaderFactory { - private static final Logger logger = LoggerFactory.getLogger(FlightPartitionReaderFactory.class); private final FlightClientOptions clientOptions; public FlightPartitionReaderFactory(FlightClientOptions clientOptions) { this.clientOptions = clientOptions; } - private FlightStream createStream(InputPartition iPartition) { + @Override + public PartitionReader createReader(InputPartition iPartition) { // This feels wrong but this is what upstream spark sources do to. FlightPartition partition = (FlightPartition) iPartition; - logger.info("Reading Flight data from locations: {}", (Object) partition.preferredLocations()); - // TODO - Should we handle multiple locations? 
- try ( - FlightClientFactory clientFactory = new FlightClientFactory( - partition.getEndpoint().getLocations().get(0), - clientOptions - ); - ) { - try (FlightClient client = clientFactory.apply()) { - return client.getStream(partition.getEndpoint().getTicket()); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - - @Override - public PartitionReader createReader(InputPartition partition) { - FlightStream stream = createStream(partition); - return new FlightPartitionReader(stream); + return new FlightPartitionReader(clientOptions, partition); } @Override - public PartitionReader createColumnarReader(InputPartition partition) { - FlightStream stream = createStream(partition); - return new FlightColumnarPartitionReader(stream); + public PartitionReader createColumnarReader(InputPartition iPartition) { + // This feels wrong but this is what upstream spark sources do to. + FlightPartition partition = (FlightPartition) iPartition; + return new FlightColumnarPartitionReader(clientOptions, partition); } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java index 7899951..de84644 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java @@ -32,7 +32,8 @@ public Batch toBatch() { @Override public InputPartition[] planInputPartitions() { InputPartition[] batches = info.getEndpoints().stream().map(endpoint -> { - return new FlightPartition(endpoint); + FlightEndpointWrapper endpointWrapper = new FlightEndpointWrapper(endpoint); + return new FlightPartition(endpointWrapper); }).toArray(InputPartition[]::new); return batches; } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index 256cb16..6080737 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -26,6 +26,7 @@ import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; @@ -43,42 +44,32 @@ import com.google.common.collect.Lists; import com.google.common.base.Joiner; -public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters { +public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters, AutoCloseable { private static final Logger LOGGER = LoggerFactory.getLogger(FlightScanBuilder.class); private static final Joiner WHERE_JOINER = Joiner.on(" and "); private static final Joiner PROJ_JOINER = Joiner.on(", "); private SchemaResult info; private StructType schema; - private final Location location; private final FlightClientOptions clientOptions; + private final FlightClientFactory clientFactory; + private final FlightClient client; private FlightDescriptor descriptor; private String sql; private Filter[] pushed; public FlightScanBuilder(Location location, FlightClientOptions clientOptions, String sql) { - this.location = location; this.clientOptions = clientOptions; this.sql = sql; descriptor = getDescriptor(sql); - try (FlightClientFactory clientFactory = new 
FlightClientFactory(location, clientOptions)) { - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } + this.clientFactory = new FlightClientFactory(location, clientOptions); + this.client = clientFactory.apply(); + info = client.getSchema(descriptor); } @Override public Scan build() { - try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { - try (FlightClient client = clientFactory.apply()) { - FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); - return new FlightScan(readSchema(), info, clientOptions); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } + FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); + return new FlightScan(readSchema(), info, clientOptions); } private boolean canBePushed(Filter filter) { @@ -156,13 +147,7 @@ public Filter[] pushFilters(Filter[] filters) { if (!pushed.isEmpty()) { String whereClause = generateWhereClause(pushed); mergeWhereDescriptors(whereClause); - try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } + info = client.getSchema(descriptor); } return notPushed.toArray(new Filter[0]); } @@ -274,13 +259,13 @@ public void pruneColumns(StructType requiredSchema) { if (!fieldNames.isEmpty()) { this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); mergeProjDescriptors(PROJ_JOINER.join(fields)); - try (FlightClientFactory clientFactory = new FlightClientFactory(location, clientOptions)) { - try (FlightClient client = clientFactory.apply()) { - info = client.getSchema(descriptor); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } + info = client.getSchema(descriptor); } } + + @Override + public void close() throws Exception { + // This order is important + AutoCloseables.close(client, clientFactory); + } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java index 6fd0548..cfff266 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java @@ -33,7 +33,11 @@ public String name() { @Override public StructType schema() { if (schema == null) { - schema = (new FlightScanBuilder(location, clientOptions, sql)).readSchema(); + try (FlightScanBuilder scanBuilder = new FlightScanBuilder(location, clientOptions, sql)) { + schema = scanBuilder.readSchema(); + } catch (Exception e) { + throw new RuntimeException(e); + } } return schema; } diff --git a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java index 42c2a2f..40e93b2 100644 --- a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -15,6 +15,9 @@ */ package org.apache.arrow.flight.spark; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; import java.util.Iterator; import java.util.List; import java.util.Optional; @@ -50,6 +53,7 @@ import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.junit.Test.None; import 
com.google.common.collect.ImmutableList;
@@ -97,6 +101,32 @@ public static void tearDown() throws Exception {
     AutoCloseables.close(server, allocator, sc);
   }
 
+  private class DummyObjectOutputStream extends ObjectOutputStream {
+    public DummyObjectOutputStream() throws IOException {
+      super(new ByteArrayOutputStream());
+    }
+  }
+
+  @Test(expected = None.class)
+  public void testFlightPartitionReaderFactorySerialization() throws IOException {
+    FlightClientOptions clientOptions = new FlightClientOptions("FooBar");
+    FlightPartitionReaderFactory readerFactory = new FlightPartitionReaderFactory(clientOptions);
+
+    try (ObjectOutputStream oos = new DummyObjectOutputStream()) {
+      oos.writeObject(readerFactory);
+    }
+  }
+
+  @Test(expected = None.class)
+  public void testFlightPartitionSerialization() throws IOException {
+    Ticket ticket = new Ticket("FooBar".getBytes());
+    FlightEndpoint endpoint = new FlightEndpoint(ticket, location);
+    FlightPartition partition = new FlightPartition(new FlightEndpointWrapper(endpoint));
+    try (ObjectOutputStream oos = new DummyObjectOutputStream()) {
+      oos.writeObject(partition);
+    }
+  }
+
   @Test
   public void testConnect() {
     csc.read("test.table");
@@ -143,17 +173,6 @@ public void testProject() {
     Assert.assertTrue(count < countOriginal);
   }
 
-  @Test
-  public void testParallel() {
-    String easySql = "select * from \"@dremio\".tpch_spark limit 100000";
-    SizeConsumer c = new SizeConsumer();
-    csc.readSql(easySql, true).toLocalIterator().forEachRemaining(c);
-    long width = c.width;
-    long length = c.length;
-    Assert.assertEquals(5, width);
-    Assert.assertEquals(40, length);
-  }
-
   private static class TestProducer extends NoOpFlightProducer {
     private boolean parallel = false;
 
@@ -167,7 +186,7 @@ public void doAction(CallContext context, Action action, StreamListener
     @Override
     public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
       Schema schema;
-      List endpoints;
+      List endpoints;
       if (parallel) {
         endpoints = ImmutableList.of(new FlightEndpoint(new Ticket(descriptor.getCommand()), location),
           new FlightEndpoint(new Ticket(descriptor.getCommand()), location));

From f492649ecbac4a1f5443f9c350f61ed72e24efcf Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Sat, 16 Apr 2022 11:02:41 -0400
Subject: [PATCH 23/38] Fix FlightScanBuilder resource leak.

---
 pom.xml                                       |  6 +-
 .../arrow/flight/spark/FlightScanBuilder.java | 61 +++++++++++++------
 .../arrow/flight/spark/FlightTable.java       |  7 +--
 3 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/pom.xml b/pom.xml
index 27e6e99..c37a104 100644
--- a/pom.xml
+++ b/pom.xml
@@ -38,7 +38,7 @@
     kr.motd.maven
     os-maven-plugin
-    1.5.0.Final
+    1.7.0
 
 
@@ -302,8 +302,8 @@ limitations under the License.
maven-compiler-plugin 3.8.1 - 8 - 8 + 9 + 9 diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index 6080737..e93fb28 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -44,32 +44,60 @@ import com.google.common.collect.Lists; import com.google.common.base.Joiner; -public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters, AutoCloseable { +public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters { private static final Logger LOGGER = LoggerFactory.getLogger(FlightScanBuilder.class); private static final Joiner WHERE_JOINER = Joiner.on(" and "); private static final Joiner PROJ_JOINER = Joiner.on(", "); - private SchemaResult info; + private SchemaResult flightSchema; private StructType schema; + private final Location location; private final FlightClientOptions clientOptions; - private final FlightClientFactory clientFactory; - private final FlightClient client; private FlightDescriptor descriptor; private String sql; private Filter[] pushed; public FlightScanBuilder(Location location, FlightClientOptions clientOptions, String sql) { + this.location = location; this.clientOptions = clientOptions; this.sql = sql; descriptor = getDescriptor(sql); - this.clientFactory = new FlightClientFactory(location, clientOptions); - this.client = clientFactory.apply(); - info = client.getSchema(descriptor); + } + + private class Client implements AutoCloseable { + private final FlightClientFactory clientFactory; + private final FlightClient client; + + public Client(Location location, FlightClientOptions clientOptions) { + this.clientFactory = new FlightClientFactory(location, clientOptions); + this.client = clientFactory.apply(); + } + + public FlightClient get() { + return client; + } + + @Override + public void close() throws Exception { + AutoCloseables.close(client, clientFactory); + } + } + + private void getFlightSchema(FlightDescriptor descriptor) { + try (Client client = new Client(location, clientOptions)) { + flightSchema = client.get().getSchema(descriptor); + } catch (Exception e) { + throw new RuntimeException(e); + } } @Override public Scan build() { - FlightInfo info = client.getInfo(FlightDescriptor.command(sql.getBytes())); - return new FlightScan(readSchema(), info, clientOptions); + try (Client client = new Client(location, clientOptions)) { + FlightInfo info = client.get().getInfo(FlightDescriptor.command(sql.getBytes())); + return new FlightScan(readSchema(), info, clientOptions); + } catch (Exception e) { + throw new RuntimeException(e); + } } private boolean canBePushed(Filter filter) { @@ -147,7 +175,7 @@ public Filter[] pushFilters(Filter[] filters) { if (!pushed.isEmpty()) { String whereClause = generateWhereClause(pushed); mergeWhereDescriptors(whereClause); - info = client.getSchema(descriptor); + getFlightSchema(descriptor); } return notPushed.toArray(new Filter[0]); } @@ -217,7 +245,10 @@ private DataType sparkFromArrow(FieldType fieldType) { } private StructType readSchemaImpl() { - StructField[] fields = info.getSchema().getFields().stream() + if (flightSchema == null) { + getFlightSchema(descriptor); + } + StructField[] fields = flightSchema.getSchema().getFields().stream() .map(field -> new StructField(field.getName(), sparkFromArrow(field.getFieldType()), 
field.isNullable(), @@ -259,13 +290,7 @@ public void pruneColumns(StructType requiredSchema) { if (!fieldNames.isEmpty()) { this.schema = new StructType(fieldsLeft.toArray(new StructField[0])); mergeProjDescriptors(PROJ_JOINER.join(fields)); - info = client.getSchema(descriptor); + getFlightSchema(descriptor); } } - - @Override - public void close() throws Exception { - // This order is important - AutoCloseables.close(client, clientFactory); - } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java index cfff266..9027994 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java @@ -33,11 +33,8 @@ public String name() { @Override public StructType schema() { if (schema == null) { - try (FlightScanBuilder scanBuilder = new FlightScanBuilder(location, clientOptions, sql)) { - schema = scanBuilder.readSchema(); - } catch (Exception e) { - throw new RuntimeException(e); - } + FlightScanBuilder scanBuilder = new FlightScanBuilder(location, clientOptions, sql); + schema = scanBuilder.readSchema(); } return schema; } From 882e0906e540cfdd46aaa67a8518ab4df6ffa047 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Sat, 16 Apr 2022 11:06:10 -0400 Subject: [PATCH 24/38] Change to Java 11. --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index c37a104..07f01d1 100644 --- a/pom.xml +++ b/pom.xml @@ -302,8 +302,8 @@ limitations under the License. maven-compiler-plugin 3.8.1 - 9 - 9 + 11 + 11 From 8400ca2f7805fa0643009275b34be45d8d73a970 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Mon, 18 Apr 2022 12:39:23 -0400 Subject: [PATCH 25/38] Upgrade maven shade plugin. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 07f01d1..bd9063d 100644 --- a/pom.xml +++ b/pom.xml @@ -334,7 +334,7 @@ limitations under the License. org.apache.maven.plugins maven-shade-plugin - 3.1.0 + 3.3.0 package From c7f1f854d7801a8684cfea560d2181d1f6ae1e5b Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Tue, 19 Apr 2022 13:30:14 -0400 Subject: [PATCH 26/38] Fix shaded jar includes. --- pom.xml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index bd9063d..4491f58 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,8 @@ 3.2.1 1.7.25 2.12.3 + UTF-8 + UTF-8 @@ -348,7 +350,8 @@ limitations under the License. org.apache.arrow:flight-grpc org.apache.arrow:arrow-vector org.apache.arrow:arrow-format - org.apache.arrow:arrow-memory + org.apache.arrow:arrow-memory-core + org.apache.arrow:arrow-memory-netty com.google.flatbuffers:flatbuffers-java io.grpc:* io.netty:* @@ -359,6 +362,7 @@ limitations under the License. com.google.api.grpc:proto-google-common-protos com.google.protobuf:protobuf-java com.google.guava:guava + com.google.guava:failureaccess io.perfmark:perfmark-api From 0a3a3c9e7279a00330f18bd02833b9517e3c1a85 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Wed, 20 Apr 2022 11:51:55 -0400 Subject: [PATCH 27/38] Add full FlightClient options including client certs; Fix deprecated SQLContext usage. 
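
As a reviewer aid, a rough caller-side sketch of the option surface after this change. Only the option names and the data source format string come from this patch; the endpoint URI, credentials, and PEM variables below are placeholders. Note that the three certificate options carry PEM contents rather than file paths (FlightClientFactory wraps them in a ByteArrayInputStream):

    // Hypothetical usage; all values are illustrative.
    Dataset<Row> df = spark.read()
        .format("org.apache.arrow.flight.spark")
        .option("uri", "grpc+tls://flight.example.com:47470")
        .option("username", "user")
        .option("password", "secret")
        .option("trustedCertificates", caCertPem)   // CA bundle, PEM text
        .option("clientCertificate", clientCertPem) // mTLS client certificate, PEM text
        .option("clientKey", clientKeyPem)          // mTLS private key, PEM text
        .load("select * from some_table");          // the load path carries the SQL command
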
--- .../arrow/flight/spark/DefaultSource.java | 36 ++++++++++----- .../flight/spark/FlightClientFactory.java | 26 ++++++++--- .../flight/spark/FlightClientOptions.java | 26 ++++++++++- .../spark/FlightPartitionReaderFactory.java | 9 ++-- .../apache/arrow/flight/spark/FlightScan.java | 5 +- .../arrow/flight/spark/FlightScanBuilder.java | 9 ++-- .../flight/spark/FlightSparkContext.java | 46 ++++++------------- .../arrow/flight/spark/FlightTable.java | 5 +- .../arrow/flight/spark/TestConnector.java | 32 ++++++------- 9 files changed, 112 insertions(+), 82 deletions(-) diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index b175d98..2b5aa1e 100644 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -1,36 +1,48 @@ package org.apache.arrow.flight.spark; +import java.net.URISyntaxException; import java.util.Map; import org.apache.arrow.flight.Location; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.broadcast.Broadcast; public class DefaultSource implements TableProvider, DataSourceRegister { + private SparkSession spark; + + private SparkSession getSparkSession() { + if (spark == null) { + spark = SparkSession.getActiveSession().get(); + } + return spark; + } private FlightTable makeTable(CaseInsensitiveStringMap options) { - String protocol = options.getOrDefault("protocol", "grpc"); + String uri = options.getOrDefault("uri", "grpc://localhost:47470"); Location location; - if (protocol == "grpc+tls") { - location = Location.forGrpcTls( - options.getOrDefault("host", "localhost"), - Integer.parseInt(options.getOrDefault("port", "47470")) - ); - } else { - location = Location.forGrpcInsecure( - options.getOrDefault("host", "localhost"), - Integer.parseInt(options.getOrDefault("port", "47470")) - ); + try { + location = new Location(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); } String sql = options.getOrDefault("path", ""); + String username = options.getOrDefault("username", ""); + String password = options.getOrDefault("password", ""); String trustedCertificates = options.getOrDefault("trustedCertificates", ""); + String clientCertificate = options.getOrDefault("clientCertificate", ""); + String clientKey = options.getOrDefault("clientKey", ""); - FlightClientOptions clientOptions = trustedCertificates.isEmpty() ? 
null : new FlightClientOptions(trustedCertificates); + Broadcast clientOptions = JavaSparkContext.fromSparkContext(getSparkSession().sparkContext()).broadcast( + new FlightClientOptions(username, password, trustedCertificates, clientCertificate, clientKey) + ); return new FlightTable( String.format("{} Location {} Command {}", shortName(), location.getUri().toString(), sql), diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index 835f050..75cc965 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -1,3 +1,4 @@ +// Portions of this file from: /* * Copyright (C) 2019 Ryan Murray * @@ -26,21 +27,32 @@ public class FlightClientFactory implements AutoCloseable { private final BufferAllocator allocator = new RootAllocator(); private final Location defaultLocation; - private InputStream trustedCertificates; + private final FlightClientOptions clientOptions; public FlightClientFactory(Location defaultLocation, FlightClientOptions clientOptions) { this.defaultLocation = defaultLocation; - if (clientOptions != null) { - this.trustedCertificates = new ByteArrayInputStream(clientOptions.getTrustedCertificates().getBytes()); - } + this.clientOptions = clientOptions; } public FlightClient apply() { FlightClient.Builder builder = FlightClient.builder(allocator, defaultLocation); - if (trustedCertificates != null) { - builder.trustedCertificates(trustedCertificates); + + if (!clientOptions.getTrustedCertificates().isEmpty()) { + builder.trustedCertificates(new ByteArrayInputStream(clientOptions.getTrustedCertificates().getBytes())); + } + + if (!clientOptions.getClientCertificate().isEmpty()) { + InputStream clientCert = new ByteArrayInputStream(clientOptions.getClientCertificate().getBytes()); + InputStream clientKey = new ByteArrayInputStream(clientOptions.getClientKey().getBytes()); + builder.clientCertificate(clientCert, clientKey); } - return builder.build(); + + FlightClient client = builder.build(); + if (!clientOptions.getUsername().isEmpty()) { + client.authenticateBasic(clientOptions.getUsername(), clientOptions.getPassword()); + } + + return client; } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java index 11cec47..6da610b 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java @@ -3,13 +3,37 @@ import java.io.Serializable; public class FlightClientOptions implements Serializable { + private final String username; + private final String password; private final String trustedCertificates; + private final String clientCertificate; + private final String clientKey; - public FlightClientOptions(String trustedCertificates) { + public FlightClientOptions(String username, String password, String trustedCertificates, String clientCertificate, String clientKey) { + this.username = username; + this.password = password; this.trustedCertificates = trustedCertificates; + this.clientCertificate = clientCertificate; + this.clientKey = clientKey; + } + + public String getUsername() { + return username; + } + + public String getPassword() { + return password; } public String getTrustedCertificates() { return trustedCertificates; } + + public String getClientCertificate() { + return 
clientCertificate; + } + + public String getClientKey() { + return clientKey; + } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java index 3a8f45f..bff3e27 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java @@ -1,5 +1,6 @@ package org.apache.arrow.flight.spark; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; @@ -7,9 +8,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; public class FlightPartitionReaderFactory implements PartitionReaderFactory { - private final FlightClientOptions clientOptions; + private final Broadcast clientOptions; - public FlightPartitionReaderFactory(FlightClientOptions clientOptions) { + public FlightPartitionReaderFactory(Broadcast clientOptions) { this.clientOptions = clientOptions; } @@ -17,14 +18,14 @@ public FlightPartitionReaderFactory(FlightClientOptions clientOptions) { public PartitionReader createReader(InputPartition iPartition) { // This feels wrong but this is what upstream spark sources do to. FlightPartition partition = (FlightPartition) iPartition; - return new FlightPartitionReader(clientOptions, partition); + return new FlightPartitionReader(clientOptions.getValue(), partition); } @Override public PartitionReader createColumnarReader(InputPartition iPartition) { // This feels wrong but this is what upstream spark sources do to. FlightPartition partition = (FlightPartition) iPartition; - return new FlightColumnarPartitionReader(clientOptions, partition); + return new FlightColumnarPartitionReader(clientOptions.getValue(), partition); } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java index de84644..e172e5e 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java @@ -3,6 +3,7 @@ import org.apache.spark.sql.connector.read.Scan; import org.apache.arrow.flight.FlightInfo; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -11,9 +12,9 @@ public class FlightScan implements Scan, Batch { private final StructType schema; private final FlightInfo info; - private final FlightClientOptions clientOptions; + private final Broadcast clientOptions; - public FlightScan(StructType schema, FlightInfo info, FlightClientOptions clientOptions) { + public FlightScan(StructType schema, FlightInfo info, Broadcast clientOptions) { this.schema = schema; this.info = info; this.clientOptions = clientOptions; diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index e93fb28..ecdf079 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.FieldType; +import 
org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.SupportsPushDownFilters; @@ -51,12 +52,12 @@ public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredC private SchemaResult flightSchema; private StructType schema; private final Location location; - private final FlightClientOptions clientOptions; + private final Broadcast clientOptions; private FlightDescriptor descriptor; private String sql; private Filter[] pushed; - public FlightScanBuilder(Location location, FlightClientOptions clientOptions, String sql) { + public FlightScanBuilder(Location location, Broadcast clientOptions, String sql) { this.location = location; this.clientOptions = clientOptions; this.sql = sql; @@ -83,7 +84,7 @@ public void close() throws Exception { } private void getFlightSchema(FlightDescriptor descriptor) { - try (Client client = new Client(location, clientOptions)) { + try (Client client = new Client(location, clientOptions.getValue())) { flightSchema = client.get().getSchema(descriptor); } catch (Exception e) { throw new RuntimeException(e); @@ -92,7 +93,7 @@ private void getFlightSchema(FlightDescriptor descriptor) { @Override public Scan build() { - try (Client client = new Client(location, clientOptions)) { + try (Client client = new Client(location, clientOptions.getValue())) { FlightInfo info = client.get().getInfo(FlightDescriptor.command(sql.getBytes())); return new FlightScan(readSchema(), info, clientOptions); } catch (Exception e) { diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java b/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java index c6d65ec..062bbc5 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java @@ -16,61 +16,41 @@ package org.apache.arrow.flight.spark; import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.DataFrameReader; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class FlightSparkContext { private SparkConf conf; - private final DataFrameReader reader; - private FlightSparkContext(SparkContext sc, SparkConf conf) { - SQLContext sqlContext = SQLContext.getOrCreate(sc); - this.conf = conf; - reader = sqlContext.read().format("org.apache.arrow.flight.spark"); - } + private final DataFrameReader reader; - public static FlightSparkContext flightContext(JavaSparkContext sc) { - return new FlightSparkContext(sc.sc(), sc.getConf()); + public FlightSparkContext(SparkSession spark) { + this.conf = spark.sparkContext().getConf(); + reader = spark.read().format("org.apache.arrow.flight.spark"); } public Dataset read(String s) { return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) + .option("uri", String.format( + "grpc://%s:%s", + conf.get("spark.flight.endpoint.host"), + conf.get("spark.flight.endpoint.port"))) .option("username", conf.get("spark.flight.auth.username")) .option("password", conf.get("spark.flight.auth.password")) - .option("parallel", false) .load(s); } public Dataset readSql(String s) { return reader.option("port", 
Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("parallel", false) - .load(s); - } - - public Dataset read(String s, boolean parallel) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) - .option("username", conf.get("spark.flight.auth.username")) - .option("password", conf.get("spark.flight.auth.password")) - .option("parallel", parallel) - .load(s); - } - - public Dataset readSql(String s, boolean parallel) { - return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port"))) - .option("host", conf.get("spark.flight.endpoint.host")) + .option("uri", String.format( + "grpc://%s:%s", + conf.get("spark.flight.endpoint.host"), + conf.get("spark.flight.endpoint.port"))) .option("username", conf.get("spark.flight.auth.username")) .option("password", conf.get("spark.flight.auth.password")) - .option("parallel", parallel) .load(s); } } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java index 9027994..ac11a35 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java @@ -3,6 +3,7 @@ import java.util.Set; import org.apache.arrow.flight.Location; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; @@ -15,10 +16,10 @@ public class FlightTable implements Table, SupportsRead { private final String name; private final Location location; private final String sql; - private final FlightClientOptions clientOptions; + private final Broadcast clientOptions; private StructType schema; - public FlightTable(String name, Location location, String sql, FlightClientOptions clientOptions) { + public FlightTable(String name, Location location, String sql, Broadcast clientOptions) { this.name = name; this.location = location; this.sql = sql; diff --git a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java index 40e93b2..40f6da4 100644 --- a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -45,10 +45,10 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; @@ -61,8 +61,7 @@ public class TestConnector { private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); private static Location location; private static FlightServer server; - private static SparkConf conf; - private static JavaSparkContext sc; + private static SparkSession spark; private static FlightSparkContext csc; @BeforeClass @@ -83,22 +82,21 @@ public boolean authenticate(ServerAuthSender outgoing, Iterator incoming }).build() ); location = 
server.getLocation(); - conf = new SparkConf() - .setAppName("flightTest") - .setMaster("local[*]") - .set("spark.driver.allowMultipleContexts", "true") - .set("spark.flight.endpoint.host", location.getUri().getHost()) - .set("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort())) - .set("spark.flight.auth.username", "xxx") - .set("spark.flight.auth.password", "yyy") - ; - sc = new JavaSparkContext(conf); - csc = FlightSparkContext.flightContext(sc); + spark = SparkSession.builder() + .appName("flightTest") + .master("local[*]") + .config("spark.driver.allowMultipleContexts", "true") + .config("spark.flight.endpoint.host", location.getUri().getHost()) + .config("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort())) + .config("spark.flight.auth.username", "xxx") + .config("spark.flight.auth.password", "yyy") + .getOrCreate(); + csc = new FlightSparkContext(spark); } @AfterClass public static void tearDown() throws Exception { - AutoCloseables.close(server, allocator, sc); + AutoCloseables.close(server, allocator, spark); } private class DummyObjectOutputStream extends ObjectOutputStream { @@ -109,8 +107,8 @@ public DummyObjectOutputStream() throws IOException { @Test(expected = None.class) public void testFlightPartitionReaderFactorySerialization() throws IOException { - FlightClientOptions clientOptions = new FlightClientOptions("FooBar"); - FlightPartitionReaderFactory readerFactory = new FlightPartitionReaderFactory(clientOptions); + FlightClientOptions clientOptions = new FlightClientOptions("xxx", "yyy", "FooBar", "FooBar", "FooBar"); + FlightPartitionReaderFactory readerFactory = new FlightPartitionReaderFactory(JavaSparkContext.fromSparkContext(spark.sparkContext()).broadcast(clientOptions)); try (ObjectOutputStream oos = new DummyObjectOutputStream()) { oos.writeObject(readerFactory); From 07c13c861d5374f77ad3fa0e2dbccf686935291f Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Thu, 21 Apr 2022 16:48:38 -0400 Subject: [PATCH 28/38] Implement Azure AD oauth client middleware --- pom.xml | 16 +++++++++ .../flight/spark/AADClientMiddleware.java | 34 +++++++++++++++++++ .../spark/AADClientMiddlewareFactory.java | 29 ++++++++++++++++ .../arrow/flight/spark/DefaultSource.java | 12 ++++++- .../flight/spark/FlightClientFactory.java | 3 ++ .../spark/FlightClientMiddlewareFactory.java | 9 +++++ .../flight/spark/FlightClientOptions.java | 9 ++++- .../arrow/flight/spark/TestConnector.java | 4 ++- 8 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java create mode 100644 src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java diff --git a/pom.xml b/pom.xml index 4491f58..1e36d79 100644 --- a/pom.xml +++ b/pom.xml @@ -32,6 +32,7 @@ 3.2.1 1.7.25 2.12.3 + 1.2.0 UTF-8 UTF-8 @@ -413,6 +414,17 @@ limitations under the License. + + + + com.azure + azure-sdk-bom + ${azure.version} + pom + import + + + org.apache.spark @@ -544,6 +556,10 @@ limitations under the License. 
4.13.1 test + + com.azure + azure-identity + diff --git a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java b/src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java new file mode 100644 index 0000000..0f58d07 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java @@ -0,0 +1,34 @@ +package org.apache.arrow.flight.spark; + +import com.azure.core.credential.TokenCredential; +import com.azure.core.credential.TokenRequestContext; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; + +public class AADClientMiddleware implements FlightClientMiddleware { + private final TokenCredential crediential; + private final TokenRequestContext requestContext; + + public AADClientMiddleware(TokenCredential crediential, TokenRequestContext requestContext) { + this.crediential = crediential; + this.requestContext = requestContext; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert("authorization", String.format("Bearer %s", crediential.getToken(requestContext).block().getToken())); + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + // Nothing needed here + } + + @Override + public void onCallCompleted(CallStatus status) { + // Nothing needed here + } + +} diff --git a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java new file mode 100644 index 0000000..0955a2c --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java @@ -0,0 +1,29 @@ +package org.apache.arrow.flight.spark; + +import com.azure.core.credential.TokenCredential; +import com.azure.core.credential.TokenRequestContext; +import com.azure.identity.ClientSecretCredentialBuilder; + +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.FlightClientMiddleware; + +public class AADClientMiddlewareFactory implements FlightClientMiddlewareFactory { + private final String clientId; + private final String clientSecret; + private final String scope; + + public AADClientMiddlewareFactory(String clientId, String clientSecret, String scope) { + this.clientId = clientId; + this.clientSecret = clientSecret; + this.scope = scope; + } + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + TokenCredential crediential = new ClientSecretCredentialBuilder().clientId(clientId).clientSecret(clientSecret).build(); + TokenRequestContext context = new TokenRequestContext(); + context.addScopes(scope); + return new AADClientMiddleware(crediential, context); + } + +} diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index 2b5aa1e..35af23b 100644 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -1,6 +1,8 @@ package org.apache.arrow.flight.spark; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import org.apache.arrow.flight.Location; @@ -39,9 +41,17 @@ private FlightTable makeTable(CaseInsensitiveStringMap options) { String trustedCertificates = options.getOrDefault("trustedCertificates", ""); String clientCertificate = options.getOrDefault("clientCertificate", ""); String clientKey = 
options.getOrDefault("clientKey", ""); + String clientId = options.getOrDefault("clientId", ""); + String clientSecret = options.getOrDefault("clientSecret", ""); + String scope = options.getOrDefault("clientScope", ""); + List middleware = new ArrayList<>(); + if (!clientId.isEmpty()) { + middleware.add(new AADClientMiddlewareFactory(clientId, clientSecret, scope)); + } + Broadcast clientOptions = JavaSparkContext.fromSparkContext(getSparkSession().sparkContext()).broadcast( - new FlightClientOptions(username, password, trustedCertificates, clientCertificate, clientKey) + new FlightClientOptions(username, password, trustedCertificates, clientCertificate, clientKey, middleware) ); return new FlightTable( diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java index 75cc965..4a7872b 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java @@ -47,6 +47,9 @@ public FlightClient apply() { builder.clientCertificate(clientCert, clientKey); } + // Add client middleware + clientOptions.getMiddleware().stream().forEach(middleware -> builder.intercept(middleware)); + FlightClient client = builder.build(); if (!clientOptions.getUsername().isEmpty()) { client.authenticateBasic(clientOptions.getUsername(), clientOptions.getPassword()); diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java new file mode 100644 index 0000000..2469df0 --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java @@ -0,0 +1,9 @@ +package org.apache.arrow.flight.spark; + +import java.io.Serializable; + +import org.apache.arrow.flight.FlightClientMiddleware; + +public interface FlightClientMiddlewareFactory extends FlightClientMiddleware.Factory, Serializable { + +} diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java index 6da610b..5c6b499 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java @@ -1,6 +1,7 @@ package org.apache.arrow.flight.spark; import java.io.Serializable; +import java.util.List; public class FlightClientOptions implements Serializable { private final String username; @@ -8,13 +9,15 @@ public class FlightClientOptions implements Serializable { private final String trustedCertificates; private final String clientCertificate; private final String clientKey; + private final List middleware; - public FlightClientOptions(String username, String password, String trustedCertificates, String clientCertificate, String clientKey) { + public FlightClientOptions(String username, String password, String trustedCertificates, String clientCertificate, String clientKey, List middleware) { this.username = username; this.password = password; this.trustedCertificates = trustedCertificates; this.clientCertificate = clientCertificate; this.clientKey = clientKey; + this.middleware = middleware; } public String getUsername() { @@ -36,4 +39,8 @@ public String getClientCertificate() { public String getClientKey() { return clientKey; } + + public List getMiddleware() { + return middleware; + } } diff --git a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java 
b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java index 40f6da4..5ddc85d 100644 --- a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -18,6 +18,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Optional; @@ -107,7 +108,8 @@ public DummyObjectOutputStream() throws IOException { @Test(expected = None.class) public void testFlightPartitionReaderFactorySerialization() throws IOException { - FlightClientOptions clientOptions = new FlightClientOptions("xxx", "yyy", "FooBar", "FooBar", "FooBar"); + List middleware = new ArrayList<>(); + FlightClientOptions clientOptions = new FlightClientOptions("xxx", "yyy", "FooBar", "FooBar", "FooBar", middleware); FlightPartitionReaderFactory readerFactory = new FlightPartitionReaderFactory(JavaSparkContext.fromSparkContext(spark.sparkContext()).broadcast(clientOptions)); try (ObjectOutputStream oos = new DummyObjectOutputStream()) { From 1fe1439e82cc08885f0564abc9ad5fdd735d0d99 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Thu, 21 Apr 2022 17:45:28 -0400 Subject: [PATCH 29/38] Bind local spark tests to localhost. --- src/test/java/org/apache/arrow/flight/spark/TestConnector.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java index 5ddc85d..3d9fc0b 100644 --- a/src/test/java/org/apache/arrow/flight/spark/TestConnector.java +++ b/src/test/java/org/apache/arrow/flight/spark/TestConnector.java @@ -86,6 +86,7 @@ public boolean authenticate(ServerAuthSender outgoing, Iterator incoming spark = SparkSession.builder() .appName("flightTest") .master("local[*]") + .config("spark.driver.host", "127.0.0.1") .config("spark.driver.allowMultipleContexts", "true") .config("spark.flight.endpoint.host", location.getUri().getHost()) .config("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort())) From da8b92ee94716bb88272dbc0b50d70c8c128eebe Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Thu, 21 Apr 2022 18:10:47 -0400 Subject: [PATCH 30/38] Add azure.identity to shaded jar. --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index 1e36d79..24858c0 100644 --- a/pom.xml +++ b/pom.xml @@ -365,6 +365,7 @@ limitations under the License. com.google.guava:guava com.google.guava:failureaccess io.perfmark:perfmark-api + com.azure:azure-identity io.netty:netty-transport-native-unix-common From 552328a0c6889135ae3f5580fa13dda8b57d7c37 Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Tue, 26 Apr 2022 08:32:11 -0400 Subject: [PATCH 31/38] Switch to simple token passing. 
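
This replaces the Azure AD middleware with a generic bearer-token pass-through, so any token acquired out of band can be used. A minimal call-pattern sketch, assuming such a token; only the "token" option and the Authorization header behaviour come from this patch, while the endpoint and token source are placeholders:

    // Hypothetical usage; TokenClientMiddleware sends "authorization: Bearer <token>".
    String bearerToken = System.getenv("FLIGHT_TOKEN"); // acquired out of band
    Dataset<Row> df = spark.read()
        .format("org.apache.arrow.flight.spark")
        .option("uri", "grpc+tls://flight.example.com:47470")
        .option("token", bearerToken)
        .load("select * from some_table");
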
Fix Postgres SQL syntax error with no alias --- pom.xml | 34 +++++++------------ .../spark/AADClientMiddlewareFactory.java | 29 ---------------- .../arrow/flight/spark/DefaultSource.java | 8 ++--- .../arrow/flight/spark/FlightScanBuilder.java | 9 +++-- ...leware.java => TokenClientMiddleware.java} | 15 +++----- .../spark/TokenClientMiddlewareFactory.java | 18 ++++++++++ 6 files changed, 45 insertions(+), 68 deletions(-) delete mode 100644 src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java rename src/main/java/org/apache/arrow/flight/spark/{AADClientMiddleware.java => TokenClientMiddleware.java} (52%) create mode 100644 src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java diff --git a/pom.xml b/pom.xml index 24858c0..29661dc 100644 --- a/pom.xml +++ b/pom.xml @@ -31,8 +31,6 @@ 7.0.0 3.2.1 1.7.25 - 2.12.3 - 1.2.0 UTF-8 UTF-8 @@ -365,13 +363,23 @@ limitations under the License. com.google.guava:guava com.google.guava:failureaccess io.perfmark:perfmark-api - com.azure:azure-identity io.netty:netty-transport-native-unix-common io.netty:netty-transport-native-epoll + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + com.google.protobuf @@ -402,7 +410,6 @@ limitations under the License. META-INF.native.netty_ META-INF.native.cdap_netty_ - @@ -415,17 +422,6 @@ limitations under the License. - - - - com.azure - azure-sdk-bom - ${azure.version} - pom - import - - - org.apache.spark @@ -489,12 +485,12 @@ limitations under the License. com.fasterxml.jackson.core jackson-core - ${jackson-core.version} + 2.12.6 com.fasterxml.jackson.core jackson-databind - ${jackson-core.version} + 2.12.6.1 org.apache.arrow @@ -557,10 +553,6 @@ limitations under the License. 4.13.1 test - - com.azure - azure-identity - diff --git a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java deleted file mode 100644 index 0955a2c..0000000 --- a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddlewareFactory.java +++ /dev/null @@ -1,29 +0,0 @@ -package org.apache.arrow.flight.spark; - -import com.azure.core.credential.TokenCredential; -import com.azure.core.credential.TokenRequestContext; -import com.azure.identity.ClientSecretCredentialBuilder; - -import org.apache.arrow.flight.CallInfo; -import org.apache.arrow.flight.FlightClientMiddleware; - -public class AADClientMiddlewareFactory implements FlightClientMiddlewareFactory { - private final String clientId; - private final String clientSecret; - private final String scope; - - public AADClientMiddlewareFactory(String clientId, String clientSecret, String scope) { - this.clientId = clientId; - this.clientSecret = clientSecret; - this.scope = scope; - } - - @Override - public FlightClientMiddleware onCallStarted(CallInfo info) { - TokenCredential crediential = new ClientSecretCredentialBuilder().clientId(clientId).clientSecret(clientSecret).build(); - TokenRequestContext context = new TokenRequestContext(); - context.addScopes(scope); - return new AADClientMiddleware(crediential, context); - } - -} diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java index 35af23b..87241f4 100644 --- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java +++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java @@ -41,12 +41,10 @@ private FlightTable makeTable(CaseInsensitiveStringMap options) { String 
trustedCertificates = options.getOrDefault("trustedCertificates", ""); String clientCertificate = options.getOrDefault("clientCertificate", ""); String clientKey = options.getOrDefault("clientKey", ""); - String clientId = options.getOrDefault("clientId", ""); - String clientSecret = options.getOrDefault("clientSecret", ""); - String scope = options.getOrDefault("clientScope", ""); + String token = options.getOrDefault("token", ""); List middleware = new ArrayList<>(); - if (!clientId.isEmpty()) { - middleware.add(new AADClientMiddlewareFactory(clientId, clientSecret, scope)); + if (!token.isEmpty()) { + middleware.add(new TokenClientMiddlewareFactory(token)); } diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java index ecdf079..399ab40 100644 --- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java +++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java @@ -85,6 +85,7 @@ public void close() throws Exception { private void getFlightSchema(FlightDescriptor descriptor) { try (Client client = new Client(location, clientOptions.getValue())) { + LOGGER.info("getSchema() descriptor: %s", descriptor); flightSchema = client.get().getSchema(descriptor); } catch (Exception e) { throw new RuntimeException(e); @@ -94,7 +95,9 @@ private void getFlightSchema(FlightDescriptor descriptor) { @Override public Scan build() { try (Client client = new Client(location, clientOptions.getValue())) { - FlightInfo info = client.get().getInfo(FlightDescriptor.command(sql.getBytes())); + FlightDescriptor descriptor = FlightDescriptor.command(sql.getBytes()); + LOGGER.info("getInfo() descriptor: %s", descriptor); + FlightInfo info = client.get().getInfo(descriptor); return new FlightScan(readSchema(), info, clientOptions); } catch (Exception e) { throw new RuntimeException(e); @@ -156,7 +159,7 @@ private FlightDescriptor getDescriptor(String sql) { } private void mergeWhereDescriptors(String whereClause) { - sql = String.format("select * from (%s) where %s", sql, whereClause); + sql = String.format("select * from (%s) as where_merge where %s", sql, whereClause); descriptor = getDescriptor(sql); } @@ -266,7 +269,7 @@ public StructType readSchema() { } private void mergeProjDescriptors(String projClause) { - sql = String.format("select %s from (%s)", projClause, sql); + sql = String.format("select %s from (%s) as proj_merge", projClause, sql); descriptor = getDescriptor(sql); } diff --git a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java similarity index 52% rename from src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java rename to src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java index 0f58d07..4ff1828 100644 --- a/src/main/java/org/apache/arrow/flight/spark/AADClientMiddleware.java +++ b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java @@ -1,24 +1,19 @@ package org.apache.arrow.flight.spark; -import com.azure.core.credential.TokenCredential; -import com.azure.core.credential.TokenRequestContext; - import org.apache.arrow.flight.CallHeaders; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightClientMiddleware; -public class AADClientMiddleware implements FlightClientMiddleware { - private final TokenCredential crediential; - private final TokenRequestContext requestContext; +public class TokenClientMiddleware 
implements FlightClientMiddleware { + private final String token; - public AADClientMiddleware(TokenCredential crediential, TokenRequestContext requestContext) { - this.crediential = crediential; - this.requestContext = requestContext; + public TokenClientMiddleware(String token) { + this.token = token; } @Override public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { - outgoingHeaders.insert("authorization", String.format("Bearer %s", crediential.getToken(requestContext).block().getToken())); + outgoingHeaders.insert("authorization", String.format("Bearer %s", token)); } @Override diff --git a/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java new file mode 100644 index 0000000..2f741fe --- /dev/null +++ b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java @@ -0,0 +1,18 @@ +package org.apache.arrow.flight.spark; + +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.FlightClientMiddleware; + +public class TokenClientMiddlewareFactory implements FlightClientMiddlewareFactory { + private final String token; + + public TokenClientMiddlewareFactory(String token) { + this.token = token; + } + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new TokenClientMiddleware(token); + } + +} From c0edd48ad4b4bdf7b843e5a974fe43dfdeb8207a Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Fri, 23 Sep 2022 05:50:29 -0400 Subject: [PATCH 32/38] Update CI to Java 11 --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 7fb7a88..fad37da 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ -dist: xenial +dist: focal language: java -jdk: openjdk8 +jdk: openjdk11 cache: directories: - $HOME/.m2 From b9b2150b8c91cac4076b9a076249cd4c5f17cbbd Mon Sep 17 00:00:00 2001 From: Kyle Brooks Date: Fri, 23 Sep 2022 06:38:01 -0400 Subject: [PATCH 33/38] Update license headers to submit back to open source upstream. --- .mvn/wrapper/MavenWrapperDownloader.java | 4 +-- .mvn/wrapper/maven-wrapper.properties | 16 ++++++++++ mvnw | 29 +++++++++---------- mvnw.cmd | 29 +++++++++---------- .../arrow/flight/spark/DefaultSource.java | 16 ++++++++++ .../flight/spark/FlightClientFactory.java | 1 - .../spark/FlightClientMiddlewareFactory.java | 16 ++++++++++ .../flight/spark/FlightClientOptions.java | 16 ++++++++++ .../spark/FlightColumnarPartitionReader.java | 16 ++++++++++ .../flight/spark/FlightEndpointWrapper.java | 16 ++++++++++ .../arrow/flight/spark/FlightPartition.java | 16 ++++++++++ .../flight/spark/FlightPartitionReader.java | 16 ++++++++++ .../spark/FlightPartitionReaderFactory.java | 16 ++++++++++ .../apache/arrow/flight/spark/FlightScan.java | 16 ++++++++++ .../arrow/flight/spark/FlightScanBuilder.java | 1 - .../arrow/flight/spark/FlightTable.java | 16 ++++++++++ .../flight/spark/TokenClientMiddleware.java | 16 ++++++++++ .../spark/TokenClientMiddlewareFactory.java | 16 ++++++++++ 18 files changed, 236 insertions(+), 36 deletions(-) diff --git a/.mvn/wrapper/MavenWrapperDownloader.java b/.mvn/wrapper/MavenWrapperDownloader.java index e76d1f3..2f60875 100644 --- a/.mvn/wrapper/MavenWrapperDownloader.java +++ b/.mvn/wrapper/MavenWrapperDownloader.java @@ -1,11 +1,11 @@ /* - * Copyright 2007-present the original author or authors. 
+ * Copyright (C) 2019 Ryan Murray
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- * https://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties
index 2743cab..cc4dcfd 100644
--- a/.mvn/wrapper/maven-wrapper.properties
+++ b/.mvn/wrapper/maven-wrapper.properties
@@ -1,3 +1,19 @@
+#
+# Copyright (C) 2019 Ryan Murray
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.6.3/apache-maven-3.6.3-bin.zip
 wrapperUrl=https://repo.maven.apache.org/maven2/io/takari/maven-wrapper/0.5.6/maven-wrapper-0.5.6.jar
diff --git a/mvnw b/mvnw
index a16b543..9afee33 100755
--- a/mvnw
+++ b/mvnw
@@ -1,22 +1,19 @@
 #!/bin/sh
-# ----------------------------------------------------------------------------
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
 #
-# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Copyright (C) 2019 Ryan Murray
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 #
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# ----------------------------------------------------------------------------

 # ----------------------------------------------------------------------------
 # Maven Start Up Batch script
diff --git a/mvnw.cmd b/mvnw.cmd
index c8d4337..f43d04c 100644
--- a/mvnw.cmd
+++ b/mvnw.cmd
@@ -1,21 +1,18 @@
-@REM ----------------------------------------------------------------------------
-@REM Licensed to the Apache Software Foundation (ASF) under one
-@REM or more contributor license agreements.  See the NOTICE file
-@REM distributed with this work for additional information
-@REM regarding copyright ownership.  The ASF licenses this file
-@REM to you under the Apache License, Version 2.0 (the
-@REM "License"); you may not use this file except in compliance
-@REM with the License.  You may obtain a copy of the License at
 @REM
-@REM https://www.apache.org/licenses/LICENSE-2.0
+@REM
+@REM Copyright (C) 2019 Ryan Murray
+@REM
+@REM Licensed under the Apache License, Version 2.0 (the "License");
+@REM you may not use this file except in compliance with the License.
+@REM You may obtain a copy of the License at
+@REM
+@REM http://www.apache.org/licenses/LICENSE-2.0
+@REM
+@REM Unless required by applicable law or agreed to in writing, software
+@REM distributed under the License is distributed on an "AS IS" BASIS,
+@REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+@REM See the License for the specific language governing permissions and
+@REM limitations under the License.
 @REM
-@REM Unless required by applicable law or agreed to in writing,
-@REM software distributed under the License is distributed on an
-@REM "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-@REM KIND, either express or implied.  See the License for the
-@REM specific language governing permissions and limitations
-@REM under the License.
-@REM ----------------------------------------------------------------------------

 @REM ----------------------------------------------------------------------------
 @REM Maven Start Up Batch script
diff --git a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
index 87241f4..9893255 100644
--- a/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
+++ b/src/main/java/org/apache/arrow/flight/spark/DefaultSource.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.net.URISyntaxException;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
index 4a7872b..3b940f3 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java
@@ -1,4 +1,3 @@
-// Portions of this file from:
 /*
  * Copyright (C) 2019 Ryan Murray
  *
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java
index 2469df0..55eeb8a 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.io.Serializable;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java
index 5c6b499..6b2e01d 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.io.Serializable;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java
index a23aa23..f11d4e6 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.io.IOException;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java b/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java
index 78df97a..5652383 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.io.IOException;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
index 51c8b8c..8c77225 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import org.apache.spark.sql.connector.read.InputPartition;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
index 70f4535..381e06e 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.io.IOException;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java
index bff3e27..6988e01 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import org.apache.spark.broadcast.Broadcast;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java
index e172e5e..0adad1b 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightScan.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightScan.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import org.apache.spark.sql.connector.read.Scan;
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java
index 399ab40..85ac7f8 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java
@@ -1,4 +1,3 @@
-// Portions of this file where taken from:
 /*
  * Copyright (C) 2019 Ryan Murray
  *
diff --git a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java
index ac11a35..3be1187 100644
--- a/src/main/java/org/apache/arrow/flight/spark/FlightTable.java
+++ b/src/main/java/org/apache/arrow/flight/spark/FlightTable.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import java.util.Set;
diff --git a/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java
index 4ff1828..61f31c7 100644
--- a/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java
+++ b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import org.apache.arrow.flight.CallHeaders;
diff --git a/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java
index 2f741fe..f5a0fdb 100644
--- a/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java
+++ b/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2019 Ryan Murray
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.arrow.flight.spark;

 import org.apache.arrow.flight.CallInfo;

From b42dcdec6024c772ee18e7702e7fa9298255ea31 Mon Sep 17 00:00:00 2001
From: Kyle Brooks
Date: Tue, 11 Oct 2022 12:42:41 -0400
Subject: [PATCH 34/38] Update copyright and Authors.

---
 .editorconfig                                            | 2 +-
 .mvn/extensions.xml                                      | 2 +-
 .mvn/wrapper/MavenWrapperDownloader.java                 | 2 +-
 .mvn/wrapper/maven-wrapper.properties                    | 2 +-
 AUTHORS                                                  | 9 +++++++++
 mvnw                                                     | 2 +-
 mvnw.cmd                                                 | 2 +-
 pom.xml                                                  | 4 ++--
 .../org/apache/arrow/flight/spark/DefaultSource.java     | 2 +-
 .../arrow/flight/spark/FlightArrowColumnVector.java      | 2 +-
 .../apache/arrow/flight/spark/FlightClientFactory.java   | 2 +-
 .../flight/spark/FlightClientMiddlewareFactory.java      | 2 +-
 .../apache/arrow/flight/spark/FlightClientOptions.java   | 2 +-
 .../flight/spark/FlightColumnarPartitionReader.java      | 2 +-
 .../apache/arrow/flight/spark/FlightEndpointWrapper.java | 2 +-
 .../org/apache/arrow/flight/spark/FlightPartition.java   | 2 +-
 .../apache/arrow/flight/spark/FlightPartitionReader.java | 2 +-
 .../arrow/flight/spark/FlightPartitionReaderFactory.java | 2 +-
 .../java/org/apache/arrow/flight/spark/FlightScan.java   | 2 +-
 .../org/apache/arrow/flight/spark/FlightScanBuilder.java | 2 +-
 .../apache/arrow/flight/spark/FlightSparkContext.java    | 2 +-
 .../java/org/apache/arrow/flight/spark/FlightTable.java  | 2 +-
 .../apache/arrow/flight/spark/TokenClientMiddleware.java | 2 +-
 .../arrow/flight/spark/TokenClientMiddlewareFactory.java | 2 +-
 .../spark/sql/execution/arrow/FlightArrowUtils.scala     | 2 +-
 .../org/apache/arrow/flight/spark/TestConnector.java     | 2 +-
 src/test/resources/logback-test.xml                      | 2 +-
 27 files changed, 36 insertions(+), 27 deletions(-)
 create mode 100644 AUTHORS

diff --git a/.editorconfig b/.editorconfig
index 7cfe605..8ba2661 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -1,5 +1,5 @@
 #
-# Copyright (C) 2019 Ryan Murray
+# Copyright (C) 2019 The flight-spark-source Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/.mvn/extensions.xml b/.mvn/extensions.xml
index 431b4b5..05ed801 100644
--- a/.mvn/extensions.xml
+++ b/.mvn/extensions.xml
@@ -1,7 +1,7 @@