Skip to content

Commit fbc471b

Browse files
committed
refactoring
1 parent 502f8ad commit fbc471b

File tree

3 files changed

+171
-225
lines changed

3 files changed

+171
-225
lines changed

sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,125 @@
1717

1818
package org.apache.spark.sql.jdbc
1919

20+
import java.sql.Connection
21+
22+
import scala.collection.JavaConverters._
2023
import scala.collection.mutable.MutableList
2124

25+
import com.spotify.docker.client.messages.ContainerConfig
2226
import com.spotify.docker.client._
2327

28+
import org.apache.spark.SparkFunSuite
29+
import org.apache.spark.sql.test.SharedSQLContext
30+
import org.scalatest.BeforeAndAfterAll
31+
32+
/**
 * A database server running inside a docker container, for use by integration tests.
 *
 * Subclasses describe the image, its environment and the JDBC URL; this class creates, starts
 * and tears down the container through a `DockerClient` obtained from `DockerClientFactory`.
 */
abstract class DatabaseOnDocker {
  /**
   * The docker image to be pulled
   */
  def imageName: String

  /**
   * A Seq of environment variables in the form of VAR=value
   */
  def env: Seq[String]

  /**
   * The JDBC URL used to connect to the database. It should be a lazy val or a function, since
   * the `ip` it relies on is only available after the docker container starts.
   */
  def jdbcUrl: String

  private val docker: DockerClient = DockerClientFactory.get()

  // Stays null until `start()` has successfully created a container.
  private var containerId: String = null

  /**
   * IP address of the container, resolved by inspecting the container on first access.
   * Only meaningful after `start()` has created the container.
   */
  lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress

  /**
   * Creates and starts the container. If the image is not available locally it is pulled
   * (up to 5 attempts) and creation is tried exactly once more, so a persistently missing
   * image surfaces as an `ImageNotFoundException` instead of looping forever.
   */
  def start(): Unit = {
    try {
      createAndStartContainer()
    } catch {
      case _: ImageNotFoundException =>
        retry(5)(docker.pull(imageName))
        createAndStartContainer()
    }
  }

  /** Builds the container config from `imageName`/`env`, then creates and starts it. */
  private def createAndStartContainer(): Unit = {
    val config = ContainerConfig.builder()
      .image(imageName).env(env.asJava)
      .build()
    containerId = docker.createContainer(config).id
    docker.startContainer(containerId)
  }

  /**
   * Evaluates `fn`, re-evaluating it after a non-fatal failure until `n` attempts have been
   * made. The failure of the last attempt, and any fatal error, is propagated to the caller.
   */
  private def retry[T](n: Int)(fn: => T): T = {
    try {
      fn
    } catch {
      case scala.util.control.NonFatal(_) if n > 1 =>
        retry(n - 1)(fn)
    }
  }

  /**
   * Kills and removes the container (if one was created) and releases the docker client.
   * Each step runs even when the previous one throws, so a failed kill does not leave the
   * container or the client behind.
   */
  def close(): Unit = {
    try {
      if (containerId != null) {
        try {
          docker.killContainer(containerId)
        } finally {
          docker.removeContainer(containerId)
        }
      }
    } finally {
      DockerClientFactory.close(docker)
    }
  }
}
84+
85+
/**
 * Base class for JDBC integration suites that run against a [[DatabaseOnDocker]].
 *
 * `beforeAll` starts the container, waits until a JDBC connection can be opened and then
 * calls `dataPreparation`; `afterAll` closes the container.
 */
abstract class DatabaseIntegrationSuite extends SparkFunSuite
  with BeforeAndAfterAll with SharedSQLContext {

  /** The dockerized database this suite runs against. */
  def db: DatabaseOnDocker

  // Guards against closing the database twice (failed beforeAll followed by afterAll).
  private var databaseClosed = false

  /**
   * Polls every 250 ms until a connection to `db.jdbcUrl` can be opened and closed.
   *
   * @param ip unused; the connection target is taken from `db.jdbcUrl`
   * @param maxMillis how long to keep trying before giving up
   * @throws java.sql.SQLException if no connection succeeded within `maxMillis`; the last
   *                               connection failure is attached as the cause
   */
  def waitForDatabase(ip: String, maxMillis: Long): Unit = {
    val before = System.currentTimeMillis()
    var lastException: java.sql.SQLException = null
    var connected = false
    while (!connected) {
      if (System.currentTimeMillis() > before + maxMillis) {
        throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
      }
      try {
        val conn = java.sql.DriverManager.getConnection(db.jdbcUrl)
        conn.close()
        connected = true
      } catch {
        case e: java.sql.SQLException =>
          lastException = e
          java.lang.Thread.sleep(250)
      }
    }
  }

  /**
   * Opens a connection to `db.jdbcUrl`, runs `dataPreparation` on it and always closes it.
   *
   * @param ip unused; the connection target is taken from `db.jdbcUrl`
   */
  def setupDatabase(ip: String): Unit = {
    val conn: Connection = java.sql.DriverManager.getConnection(db.jdbcUrl)
    try {
      dataPreparation(conn)
    } finally {
      conn.close()
    }
  }

  /**
   * Prepare databases and tables for testing
   */
  def dataPreparation(connection: Connection): Unit

  /** Closes the database at most once, even if the close itself fails. */
  private def closeDatabaseOnce(): Unit = {
    if (!databaseClosed) {
      databaseClosed = true
      db.close()
    }
  }

  override def beforeAll(): Unit = {
    super.beforeAll()
    try {
      db.start()
      waitForDatabase(db.ip, 60000)
      setupDatabase(db.ip)
    } catch {
      case failure: Throwable =>
        // afterAll is not guaranteed to run when beforeAll throws, so release the container
        // here; a cleanup failure is attached to the original error rather than replacing it.
        try {
          closeDatabaseOnce()
        } catch {
          case scala.util.control.NonFatal(cleanupFailure) =>
            failure.addSuppressed(cleanupFailure)
        }
        throw failure
    }
  }

  override def afterAll(): Unit = {
    try {
      closeDatabaseOnce()
    } finally {
      super.afterAll()
    }
  }
}
138+
24139
/**
25140
* A factory and morgue for DockerClient objects. In the DockerClient we use,
26141
* calling close() closes the desired DockerClient but also renders all other

sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala

Lines changed: 40 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -18,132 +18,44 @@
1818
package org.apache.spark.sql.jdbc
1919

2020
import java.math.BigDecimal
21-
import java.sql.{Date, Timestamp}
21+
import java.sql.{Connection, Date, Timestamp}
2222
import java.util.Properties
2323

24-
import org.scalatest.BeforeAndAfterAll
25-
26-
import com.spotify.docker.client.{ImageNotFoundException, DockerClient}
27-
import com.spotify.docker.client.messages.ContainerConfig
28-
29-
import org.apache.spark.SparkFunSuite
30-
import org.apache.spark.sql.test.SharedSQLContext
31-
32-
class MySQLDatabase {
33-
val docker: DockerClient = DockerClientFactory.get()
34-
var containerId: String = null
35-
36-
start()
37-
38-
def start(): Unit = {
39-
while (true) {
40-
try {
41-
val config = ContainerConfig.builder()
42-
.image("mysql").env("MYSQL_ROOT_PASSWORD=rootpass")
43-
.build()
44-
containerId = docker.createContainer(config).id
45-
docker.startContainer(containerId)
46-
return
47-
} catch {
48-
case e: ImageNotFoundException => retry(3)(docker.pull("mysql:latest"))
49-
}
50-
}
51-
}
52-
53-
private def retry[T](n: Int)(fn: => T): T = {
54-
try {
55-
fn
56-
} catch {
57-
case e if n > 1 =>
58-
retry(n - 1)(fn)
59-
}
60-
}
61-
62-
lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
63-
64-
def close(): Unit = {
65-
docker.killContainer(containerId)
66-
docker.removeContainer(containerId)
67-
DockerClientFactory.close(docker)
68-
}
69-
}
70-
71-
class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext {
72-
lazy val db: MySQLDatabase = new MySQLDatabase()
73-
var ip: String = null
74-
75-
def url(ip: String): String = url(ip, "mysql")
76-
def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass"
77-
78-
def waitForDatabase(ip: String, maxMillis: Long) {
79-
val before = System.currentTimeMillis()
80-
var lastException: java.sql.SQLException = null
81-
while (true) {
82-
if (System.currentTimeMillis() > before + maxMillis) {
83-
throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
84-
}
85-
try {
86-
val conn = java.sql.DriverManager.getConnection(url(ip))
87-
conn.close()
88-
return
89-
} catch {
90-
case e: java.sql.SQLException =>
91-
lastException = e
92-
java.lang.Thread.sleep(250)
93-
}
94-
}
95-
}
96-
97-
def setupDatabase(ip: String) {
98-
val conn = java.sql.DriverManager.getConnection(url(ip))
99-
try {
100-
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
101-
conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate()
102-
conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate()
103-
conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate()
104-
105-
conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), "
106-
+ "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
107-
+ "dbl DOUBLE)").executeUpdate()
108-
conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', "
109-
+ "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
110-
+ "42.75, 1.0000000000000002)").executeUpdate()
111-
112-
conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
113-
+ "yr YEAR)").executeUpdate()
114-
conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', "
115-
+ "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
116-
117-
// TODO: Test locale conversion for strings.
118-
conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
119-
+ "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
120-
).executeUpdate()
121-
conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', " +
122-
"'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
123-
} finally {
124-
conn.close()
125-
}
126-
}
127-
128-
override def beforeAll() {
129-
// If you load the MySQL driver here, DriverManager will deadlock. The
130-
// MySQL driver gets loaded when its jar gets loaded, unlike the Postgres
131-
// and H2 drivers.
132-
// scalastyle:off classforname
133-
// Class.forName("com.mysql.jdbc.Driver")
134-
// scalastyle:on classforname
135-
super.beforeAll()
136-
waitForDatabase(db.ip, 60000)
137-
setupDatabase(db.ip)
138-
ip = db.ip
24+
class MySQLIntegrationSuite extends DatabaseIntegrationSuite {
25+
// MySQL server container: root password is set through the environment and the JDBC URL
// connects as root to the `mysql` database on port 3306.
val db = new DatabaseOnDocker {
  override val imageName: String = "mysql:latest"
  override val env: Seq[String] = Seq("MYSQL_ROOT_PASSWORD=rootpass")
  // Lazy because `ip` can only be resolved once the container has been started.
  override lazy val jdbcUrl: String =
    s"jdbc:mysql://$ip:3306/mysql?user=root&password=rootpass"
}
14030

141-
override def afterAll() {
142-
db.close()
31+
/**
 * Creates the `foo` database and the `tbl`, `numbers`, `dates` and `strings` tables, and
 * inserts the rows the tests read back. Table names are unqualified, so the tables are
 * created in whichever database the connection has selected.
 *
 * @param conn an open connection to the database; it is not closed here
 */
override def dataPreparation(conn: Connection): Unit = {
  // Runs a single update statement and closes it afterwards, so statements do not
  // accumulate on the connection.
  def execute(sql: String): Unit = {
    val statement = conn.prepareStatement(sql)
    try {
      statement.executeUpdate()
    } finally {
      statement.close()
    }
  }

  execute("CREATE DATABASE foo")
  execute("CREATE TABLE tbl (x INTEGER, y TEXT(8))")
  execute("INSERT INTO tbl VALUES (42,'fred')")
  execute("INSERT INTO tbl VALUES (17,'dave')")

  execute("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), "
    + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
    + "dbl DOUBLE)")
  execute("INSERT INTO numbers VALUES (b'0', b'1000100101', "
    + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
    + "42.75, 1.0000000000000002)")

  execute("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
    + "yr YEAR)")
  execute("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
    + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')")

  // TODO: Test locale conversion for strings.
  execute("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
    + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
  )
  execute("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
    "'jumps', 'over', 'the', 'lazy', 'dog')")
}
14456

14557
test("Basic test") {
146-
val df = sqlContext.read.jdbc(url(ip, "foo"), "tbl", new Properties)
58+
val df = sqlContext.read.jdbc(db.jdbcUrl, "tbl", new Properties)
14759
val rows = df.collect()
14860
assert(rows.length == 2)
14961
val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -153,7 +65,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
15365
}
15466

15567
test("Numeric types") {
156-
val df = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties)
68+
val df = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
15769
val rows = df.collect()
15870
assert(rows.length == 1)
15971
val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -180,7 +92,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
18092
}
18193

18294
test("Date types") {
183-
val df = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties)
95+
val df = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
18496
val rows = df.collect()
18597
assert(rows.length == 1)
18698
val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -198,7 +110,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
198110
}
199111

200112
test("String types") {
201-
val df = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties)
113+
val df = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
202114
val rows = df.collect()
203115
assert(rows.length == 1)
204116
val types = rows(0).toSeq.map(x => x.getClass.toString)
@@ -224,11 +136,11 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
224136
}
225137

226138
test("Basic write test") {
227-
val df1 = sqlContext.read.jdbc(url(ip, "foo"), "numbers", new Properties)
228-
val df2 = sqlContext.read.jdbc(url(ip, "foo"), "dates", new Properties)
229-
val df3 = sqlContext.read.jdbc(url(ip, "foo"), "strings", new Properties)
230-
df1.write.jdbc(url(ip, "foo"), "numberscopy", new Properties)
231-
df2.write.jdbc(url(ip, "foo"), "datescopy", new Properties)
232-
df3.write.jdbc(url(ip, "foo"), "stringscopy", new Properties)
139+
val df1 = sqlContext.read.jdbc(db.jdbcUrl, "numbers", new Properties)
140+
val df2 = sqlContext.read.jdbc(db.jdbcUrl, "dates", new Properties)
141+
val df3 = sqlContext.read.jdbc(db.jdbcUrl, "strings", new Properties)
142+
df1.write.jdbc(db.jdbcUrl, "numberscopy", new Properties)
143+
df2.write.jdbc(db.jdbcUrl, "datescopy", new Properties)
144+
df3.write.jdbc(db.jdbcUrl, "stringscopy", new Properties)
233145
}
234146
}

0 commit comments

Comments
 (0)