Querying JDBC database in parallel with Google Dataflow (Apache Beam)

 

NOTE: The Java example code for this technical blog can be found in this GitHub repo: https://github.com/HocLengChung/Apache-Beam-Jdbc-Parallel-Read

Consider the following situation: you want to use a single query to read a JDBC-compatible database such as Google Cloud SQL (MySQL) that contains millions of rows, for example when migrating legacy database data to BigQuery. At some point the number of rows becomes so large that a single query yields more results than your virtual machines in Google Dataflow can handle. The pipeline then hangs and never completes because the workers run out of memory.

The Challenge

The Java SDK of Apache Beam has a built-in I/O transform, JdbcIO, that can read from and write to a JDBC source. JdbcIO.read() reads the source using a single query. However, if your query returns millions of rows, the read becomes slow and may not complete because the machine runs out of memory. The JDBC connector of Apache Spark has a query partitioning function that divides a query into partitions before executing it. But running Spark means that you need to provision virtual machines on Dataproc before you can run the pipeline. With Apache Beam you can run the pipeline directly on Google Dataflow, and the machines are provisioned according to the pipeline parameters you specify. It is a serverless, on-demand solution. Would it be possible to do something like this in Apache Beam?

Building a partitioned JDBC query pipeline (Java Apache Beam).

Apache Beam's JdbcIO.readAll() transform can query a source in parallel, given a PCollection of elements that are used as query parameters. In order to query a table in parallel, we need to construct queries that each cover one range of the table. Consider for example a MySQL table with an auto-increment column ‘index_id’. The SQL DDL command for the example table is:

 

CREATE TABLE `HelloWorld` (
  `ID` varchar(255) DEFAULT NULL,
  `Job Title` varchar(255) DEFAULT NULL,
  `Email Address` varchar(255) DEFAULT NULL,
  `FirstName LastName` varchar(255) DEFAULT NULL,
  `Address` varchar(255) DEFAULT NULL,
  `integer_column` varchar(255) DEFAULT NULL,
  `string_column` varchar(255) DEFAULT NULL,
  `index_id` int(11) NOT NULL AUTO_INCREMENT,
  PRIMARY KEY (`index_id`)
)

I have populated this table with 16 million rows of data inside a Google Cloud SQL instance. 
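To make the idea of range queries concrete, here is a minimal plain-Java sketch, independent of Apache Beam. The class and method names (RangeQuerySketch, rangeQueries) are hypothetical and only serve this illustration. It assumes ids start at 1 and renders the bounds into the SQL text for readability, whereas a real pipeline would bind them as query parameters.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: splits an id interval into half-open ranges, one query per range. */
final class RangeQuerySketch {

    /** Builds one SELECT per range [from, to), covering ids 0..maxId in chunks of chunkSize. */
    static List<String> rangeQueries(String table, long maxId, long chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        List<String> queries = new ArrayList<>();
        long end = maxId + 1; // exclusive upper bound, so the row with id == maxId is included
        for (long from = 0; from < end; from += chunkSize) {
            long to = Math.min(from + chunkSize, end);
            queries.add("SELECT * FROM " + table
                    + " WHERE index_id >= " + from + " AND index_id < " + to);
        }
        return queries;
    }

    public static void main(String[] args) {
        // Small example: highest id 2500, chunks of 1000 ids
        rangeQueries("HelloWorld", 2500, 1000).forEach(System.out::println);
        // Number of range queries for 16 million ids in chunks of 1000
        System.out.println(rangeQueries("HelloWorld", 16_000_000L, 1000).size());
    }
}
```

For the small example this yields three queries, the last one covering index_id >= 2000 and index_id < 2501. For 16 million ids in chunks of 1000 it yields 16,001 queries, because the ranges start at 0 while the ids start at 1.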

Step by step: implementing a partitioned parallel JdbcIO read

1) Add dependencies

A JDBC connection pool will make concurrent JDBC connections more efficient. I use c3p0 for this purpose:

    <dependency>
      <groupId>com.mchange</groupId>
      <artifactId>c3p0</artifactId>
      <version>0.9.5.4</version>
    </dependency>

We also need to add the JDBC dependencies to the Java Maven project in order to run the example:

    <dependency>
      <groupId>org.apache.beam</groupId>
      <artifactId>beam-sdks-java-io-jdbc</artifactId>
      <version>${beam.version}</version>
    </dependency>
    <dependency>
      <groupId>mysql</groupId>
      <artifactId>mysql-connector-java</artifactId>
      <version>8.0.16</version>
    </dependency>
    <dependency>
      <groupId>com.google.cloud.sql</groupId>
      <artifactId>mysql-socket-factory-connector-j-8</artifactId>
      <version>1.0.15</version>
    </dependency>

2) Configure the connection pool that was added as a dependency in step 1)

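import com.mchange.v2.c3p0.ComboPooledDataSource;
import org.apache.beam.sdk.io.jdbc.JdbcIO;

      // c3p0 connection pool; note that setDriverClass throws the checked PropertyVetoException
      ComboPooledDataSource dataSource = new ComboPooledDataSource();
      dataSource.setDriverClass("com.mysql.cj.jdbc.Driver");
      // Connect to Cloud SQL through the socket factory; replace the <...> placeholders
      dataSource.setJdbcUrl("jdbc:mysql://google/<DATABASE_NAME>?cloudSqlInstance=<INSTANCE_CONNECTION_NAME>" +
              "&socketFactory=com.google.cloud.sql.mysql.SocketFactory&useSSL=false" +
              "&user=<MYSQL_USER_NAME>&password=<MYSQL_USER_PASSWORD>");
      dataSource.setMaxPoolSize(10);
      dataSource.setInitialPoolSize(6);
      JdbcIO.DataSourceConfiguration config
              = JdbcIO.DataSourceConfiguration.create(dataSource);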

3) We can now build the first part of the pipeline

Consider the “Read from Cloud SQL MySQL: HelloWorld” step of the pipeline below. I use the index_id field for pagination and divide the table into ranges of 1000 rows each. To get the number of rows in the table, I use the query:

SELECT MAX(`index_id`) from HelloWorld

This is faster than a SELECT COUNT(*) query, because the MySQL database uses the InnoDB engine, which is slow at SELECT COUNT(*) on a table with a large number of rows. Note that this solution is for demonstration purposes only. If rows have been deleted from your table in the past, the maximum value of index_id is no longer an accurate row count: it overestimates the number of rows. The ranges still cover every remaining row, but some of them contain fewer rows than others. If you need the exact number of rows, you could store it in a separate metadata table.
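The difference between the row count and the highest id is easy to see on a small in-memory example. The sketch below is illustrative only: the class name MaxVersusCountDemo and the list of ids are made up, and the list stands in for the index_id column of a table in which some rows were deleted.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Illustrative only: compares COUNT(*) and MAX(index_id) on an id column with gaps. */
final class MaxVersusCountDemo {

    /** Counts the ids in each half-open range [from, from + chunk) up to the highest id. */
    static List<Long> rowsPerRange(List<Integer> ids, int chunk) {
        int end = Collections.max(ids) + 1; // exclusive upper bound: highest id + 1
        List<Long> counts = new ArrayList<>();
        for (int from = 0; from < end; from += chunk) {
            final int lower = from;
            final int upper = Math.min(from + chunk, end);
            counts.add(ids.stream().filter(id -> id >= lower && id < upper).count());
        }
        return counts;
    }

    public static void main(String[] args) {
        // ids 3, 4, 6, 7 and 8 were deleted: 5 rows remain, but the highest id is 10
        List<Integer> ids = Arrays.asList(1, 2, 5, 9, 10);
        System.out.println("COUNT(*) = " + ids.size());
        System.out.println("MAX(index_id) = " + Collections.max(ids));
        System.out.println("rows per range of 4 ids = " + rowsPerRange(ids, 4));
    }
}
```

With these ids the three ranges [0,4), [4,8) and [8,11) contain 2, 1 and 2 rows. Ranges derived from the highest id reach all 5 rows, while ranges derived from the row count would stop at id 5 and miss ids 9 and 10.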

In the “Distribute” step of the pipeline, I convert the value returned by the query into ranges. Each range is emitted as an Apache Beam key-value instance with a String as key and an Integer as value (KV<String,Integer>). The key holds the two range indices separated by a comma, for example 0,1000, and the value is just the fixed integer 1.

In the “Break Fusion” step I invoke a GroupByKey transform, which groups the KVs of the earlier step by key. This step is needed in order to “break the fusion”. Dataflow can fuse consecutive steps into a single one, so if your pipeline contains an operation with a high number of output elements, those elements stay on the same machine to be processed by the next step. You need to redistribute them over all available machines in order to keep the parallel processing capabilities of Apache Beam. See the Dataflow documentation on deploying a pipeline for more details.

The code excerpt for these steps is displayed below:

PCollection<KV<String, Iterable<Integer>>> ranges =
        p.apply(String.format("Read from Cloud SQL MySQL: %s", tableName), JdbcIO.<String>read()
                .withDataSourceConfiguration(config)
                .withQuery(String.format("SELECT MAX(`index_id`) from %s", tableName))
                .withRowMapper(new JdbcIO.RowMapper<String>() {
                    public String mapRow(ResultSet resultSet) throws Exception {
                        return resultSet.getString(1);
                    }
                })
                .withOutputParallelization(false)
                .withCoder(StringUtf8Coder.of()))
        .apply("Distribute", ParDo.of(new DoFn<String, KV<String, Integer>>() {
            @ProcessElement
            public void processElement(ProcessContext c) {
                int readChunk = fetchSize;
                // The read query selects index_id >= from and index_id < to, so the last
                // range has to end at MAX(index_id) + 1 to include the row with the highest id.
                int end = Integer.parseInt(c.element()) + 1;
                for (int indexFrom = 0; indexFrom < end; indexFrom += readChunk) {
                    int indexTo = Math.min(indexFrom + readChunk, end);
                    // Key: "from,to"; the value is a dummy that is only needed for the GroupByKey
                    String range = String.format("%s,%s", indexFrom, indexTo);
                    c.output(KV.of(range, 1));
                }
            }
        }))
        .apply("Break Fusion", GroupByKey.create());
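Off-by-one mistakes at the range boundaries are easy to make, so it can help to check the generated keys outside of the pipeline. The following plain-Java sketch is one way to do that. The class and method names (RangeCoverageCheck, coversExactly) are hypothetical, and it assumes keys of the form from,to that describe half-open ranges, as used with a query that selects index_id >= from and index_id < to.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Illustrative only: checks that "from,to" keys tile an interval without gaps or overlaps. */
final class RangeCoverageCheck {

    /** Returns true if the half-open ranges cover [start, end) exactly once. */
    static boolean coversExactly(List<String> rangeKeys, int start, int end) {
        List<int[]> ranges = new ArrayList<>();
        for (String key : rangeKeys) {
            String[] parts = key.split(",");
            ranges.add(new int[] {Integer.parseInt(parts[0]), Integer.parseInt(parts[1])});
        }
        // Keys arrive in arbitrary order after a GroupByKey, so sort by lower bound first
        ranges.sort(Comparator.comparingInt(r -> r[0]));
        int expectedFrom = start;
        for (int[] range : ranges) {
            if (range[0] != expectedFrom || range[1] <= range[0]) {
                return false; // gap, overlap or empty range
            }
            expectedFrom = range[1];
        }
        return expectedFrom == end;
    }

    public static void main(String[] args) {
        int maxId = 2500; // stands in for the result of SELECT MAX(index_id)
        List<String> inclusive = Arrays.asList("2000,2501", "0,1000", "1000,2000");
        List<String> tooShort = Arrays.asList("0,1000", "1000,2000", "2000,2500");
        System.out.println(coversExactly(inclusive, 0, maxId + 1));
        System.out.println(coversExactly(tooShort, 0, maxId + 1));
    }
}
```

The first list passes the check. The second one fails because its last range ends at 2500, so a query with index_id < 2500 would never return the row with index_id 2500.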

4) The last step is to use the created ranges to read the JDBC source in parallel

For that I use the query:

select * from <DATABASE_NAME>.HelloWorld where index_id >= ? and index_id < ?

and a ParameterSetter. For illustration purposes, I added a RowMapper that maps the selected rows into a readable JSON String using Jackson. I also added a simple map function that does something after reading the data from the database in parallel. The code excerpt for these steps is displayed below:

ranges.apply(String.format("Read ALL %s", tableName), JdbcIO.<KV<String, Iterable<Integer>>, String>readAll()
                .withDataSourceConfiguration(config)
                .withFetchSize(fetchSize)
                .withCoder(StringUtf8Coder.of())
                .withParameterSetter(new JdbcIO.PreparedStatementSetter<KV<String, Iterable<Integer>>>() {
                    @Override
                    public void setParameters(KV<String, Iterable<Integer>> element,
                                              PreparedStatement preparedStatement) throws Exception {
                        // The key has the form "from,to"; bind both bounds to the query
                        String[] range = element.getKey().split(",");
                        preparedStatement.setInt(1, Integer.parseInt(range[0]));
                        preparedStatement.setInt(2, Integer.parseInt(range[1]));
                    }
                })
                .withOutputParallelization(false)
                .withQuery(String.format("select * from <DATABASE_NAME>.%s where index_id >= ? and index_id < ?", tableName))
                .withRowMapper((JdbcIO.RowMapper<String>) resultSet -> {
                    // Map one row to a JSON array with one object per column
                    ObjectMapper mapper = new ObjectMapper();
                    ArrayNode arrayNode = mapper.createArrayNode();
                    for (int i = 1; i <= resultSet.getMetaData().getColumnCount(); i++) {
                        try {
                            ObjectNode objectNode = mapper.createObjectNode();
                            objectNode.put("column_name", resultSet.getMetaData().getColumnName(i));
                            objectNode.put("value", resultSet.getString(i));
                            arrayNode.add(objectNode);
                        } catch (Exception e) {
                            LOG.error("problem reading column " + i);
                            throw e;
                        }
                    }
                    return mapper.writeValueAsString(arrayNode);
                })
        )
        // Placeholder for your own processing: here, the length of each JSON string
        .apply(MapElements.via(
                new SimpleFunction<String, Integer>() {
                    @Override
                    public Integer apply(String line) {
                        return line.length();
                    }
                }))
;
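To give an impression of what one element looks like after the RowMapper, here is a small stand-in that produces the same JSON shape with plain Java. It is a sketch for illustration only: the class name RowJsonShape and the sample values are made up, the article's code uses Jackson for this, and the simplified escaping below is not a replacement for a JSON library.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Illustrative only: mimics the JSON shape produced by the row mapper, without Jackson. */
final class RowJsonShape {

    /** Renders a JSON string literal, or the JSON literal null for a SQL NULL. */
    static String quote(String s) {
        if (s == null) {
            return "null";
        }
        StringBuilder sb = new StringBuilder("\"");
        for (char ch : s.toCharArray()) {
            if (ch == '"' || ch == '\\') {
                sb.append('\\').append(ch);
            } else if (ch < 0x20) {
                sb.append(String.format("\\u%04x", (int) ch)); // control characters
            } else {
                sb.append(ch);
            }
        }
        return sb.append('"').toString();
    }

    /** One JSON object per column, in column order, wrapped in a JSON array. */
    static String toJson(Map<String, String> row) {
        StringBuilder sb = new StringBuilder("[");
        for (Map.Entry<String, String> column : row.entrySet()) {
            if (sb.length() > 1) {
                sb.append(',');
            }
            sb.append("{\"column_name\":").append(quote(column.getKey()))
              .append(",\"value\":").append(quote(column.getValue())).append('}');
        }
        return sb.append(']').toString();
    }

    public static void main(String[] args) {
        Map<String, String> row = new LinkedHashMap<>(); // keeps the column order
        row.put("ID", "42");
        row.put("Job Title", "Engineer");
        row.put("Address", null);
        System.out.println(toJson(row));
    }
}
```

Every row becomes a single String, so the element type stays simple and StringUtf8Coder is enough. The column names are repeated in every element, which is convenient for a demonstration but adds overhead on large tables.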

By now you have a pipeline that reads a JDBC source in parallel. In my example I got a throughput of over 250k elements per second with three n1-standard-8 machines.

Conclusion

In short, this article explained how to read from a JDBC source in parallel using the JdbcIO.readAll() transform of Apache Beam. You can use the provided example to implement your own parallel JDBC read pipeline within Apache Beam. If you want to learn more about Apache Beam, the official Apache Beam documentation is a good starting point.

If you want to use this solution for your own pipelines please look at the GitHub repo that contains the code: https://github.com/HocLengChung/Apache-Beam-Jdbc-Parallel-Read

Just change the parameters of the example to your needs.

Data Integration At Devoteam

Data and its value are growing exponentially in the 21st century. Major IT challenges are mostly related to data (and its massive growth). These challenges often include questions such as: How do I get a grip on data? Where is my data stored? Which (advanced) analytics possibilities are there to optimize its use? How reliable or trustworthy is my data? This begs for an integrated data strategy.


Contact

Hoc Leng Chung
Data Analytics Consultant