1818package org .apache .spark .sql .jdbc
1919
2020import java .math .BigDecimal
21- import java .sql .{Date , Timestamp }
21+ import java .sql .{Connection , Date , Timestamp }
2222import java .util .Properties
2323
24- import org .scalatest .BeforeAndAfterAll
25-
26- import com .spotify .docker .client .{ImageNotFoundException , DockerClient }
27- import com .spotify .docker .client .messages .ContainerConfig
28-
29- import org .apache .spark .SparkFunSuite
30- import org .apache .spark .sql .test .SharedSQLContext
31-
32- class MySQLDatabase {
33- val docker : DockerClient = DockerClientFactory .get()
34- var containerId : String = null
35-
36- start()
37-
38- def start (): Unit = {
39- while (true ) {
40- try {
41- val config = ContainerConfig .builder()
42- .image(" mysql" ).env(" MYSQL_ROOT_PASSWORD=rootpass" )
43- .build()
44- containerId = docker.createContainer(config).id
45- docker.startContainer(containerId)
46- return
47- } catch {
48- case e : ImageNotFoundException => retry(3 )(docker.pull(" mysql:latest" ))
49- }
50- }
51- }
52-
53- private def retry [T ](n : Int )(fn : => T ): T = {
54- try {
55- fn
56- } catch {
57- case e if n > 1 =>
58- retry(n - 1 )(fn)
59- }
60- }
61-
62- lazy val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
63-
64- def close (): Unit = {
65- docker.killContainer(containerId)
66- docker.removeContainer(containerId)
67- DockerClientFactory .close(docker)
68- }
69- }
70-
71- class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with SharedSQLContext {
72- lazy val db : MySQLDatabase = new MySQLDatabase ()
73- var ip : String = null
74-
75- def url (ip : String ): String = url(ip, " mysql" )
76- def url (ip : String , db : String ): String = s " jdbc:mysql:// $ip:3306/ $db?user=root&password=rootpass "
77-
78- def waitForDatabase (ip : String , maxMillis : Long ) {
79- val before = System .currentTimeMillis()
80- var lastException : java.sql.SQLException = null
81- while (true ) {
82- if (System .currentTimeMillis() > before + maxMillis) {
83- throw new java.sql.SQLException (s " Database not up after $maxMillis ms. " , lastException)
84- }
85- try {
86- val conn = java.sql.DriverManager .getConnection(url(ip))
87- conn.close()
88- return
89- } catch {
90- case e : java.sql.SQLException =>
91- lastException = e
92- java.lang.Thread .sleep(250 )
93- }
94- }
95- }
96-
97- def setupDatabase (ip : String ) {
98- val conn = java.sql.DriverManager .getConnection(url(ip))
99- try {
100- conn.prepareStatement(" CREATE DATABASE foo" ).executeUpdate()
101- conn.prepareStatement(" CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))" ).executeUpdate()
102- conn.prepareStatement(" INSERT INTO foo.tbl VALUES (42,'fred')" ).executeUpdate()
103- conn.prepareStatement(" INSERT INTO foo.tbl VALUES (17,'dave')" ).executeUpdate()
104-
105- conn.prepareStatement(" CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), "
106- + " small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
107- + " dbl DOUBLE)" ).executeUpdate()
108- conn.prepareStatement(" INSERT INTO foo.numbers VALUES (b'0', b'1000100101', "
109- + " 17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
110- + " 42.75, 1.0000000000000002)" ).executeUpdate()
111-
112- conn.prepareStatement(" CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
113- + " yr YEAR)" ).executeUpdate()
114- conn.prepareStatement(" INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', "
115- + " '1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')" ).executeUpdate()
116-
117- // TODO: Test locale conversion for strings.
118- conn.prepareStatement(" CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
119- + " d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
120- ).executeUpdate()
121- conn.prepareStatement(" INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', " +
122- " 'jumps', 'over', 'the', 'lazy', 'dog')" ).executeUpdate()
123- } finally {
124- conn.close()
125- }
126- }
127-
128- override def beforeAll () {
129- // If you load the MySQL driver here, DriverManager will deadlock. The
130- // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres
131- // and H2 drivers.
132- // scalastyle:off classforname
133- // Class.forName("com.mysql.jdbc.Driver")
134- // scalastyle:on classforname
135- super .beforeAll()
136- waitForDatabase(db.ip, 60000 )
137- setupDatabase(db.ip)
138- ip = db.ip
24+ class MySQLIntegrationSuite extends DatabaseIntegrationSuite {
25+ val db = new DatabaseOnDocker {
26+ val imageName = " mysql:latest"
27+ val env = Seq (" MYSQL_ROOT_PASSWORD=rootpass" )
28+ lazy val jdbcUrl = s " jdbc:mysql:// $ip:3306/mysql?user=root&password=rootpass "
13929 }
14030
141- override def afterAll () {
142- db.close()
31+ override def dataPreparation (conn : Connection ) {
32+ conn.prepareStatement(" CREATE DATABASE foo" ).executeUpdate()
33+ conn.prepareStatement(" CREATE TABLE tbl (x INTEGER, y TEXT(8))" ).executeUpdate()
34+ conn.prepareStatement(" INSERT INTO tbl VALUES (42,'fred')" ).executeUpdate()
35+ conn.prepareStatement(" INSERT INTO tbl VALUES (17,'dave')" ).executeUpdate()
36+
37+ conn.prepareStatement(" CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), "
38+ + " small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
39+ + " dbl DOUBLE)" ).executeUpdate()
40+ conn.prepareStatement(" INSERT INTO numbers VALUES (b'0', b'1000100101', "
41+ + " 17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
42+ + " 42.75, 1.0000000000000002)" ).executeUpdate()
43+
44+ conn.prepareStatement(" CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
45+ + " yr YEAR)" ).executeUpdate()
46+ conn.prepareStatement(" INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
47+ + " '1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')" ).executeUpdate()
48+
49+ // TODO: Test locale conversion for strings.
50+ conn.prepareStatement(" CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
51+ + " d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
52+ ).executeUpdate()
53+ conn.prepareStatement(" INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " +
54+ " 'jumps', 'over', 'the', 'lazy', 'dog')" ).executeUpdate()
14355 }
14456
14557 test(" Basic test" ) {
146- val df = sqlContext.read.jdbc(url(ip, " foo " ) , " tbl" , new Properties )
58+ val df = sqlContext.read.jdbc(db.jdbcUrl , " tbl" , new Properties )
14759 val rows = df.collect()
14860 assert(rows.length == 2 )
14961 val types = rows(0 ).toSeq.map(x => x.getClass.toString)
@@ -153,7 +65,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
15365 }
15466
15567 test(" Numeric types" ) {
156- val df = sqlContext.read.jdbc(url(ip, " foo " ) , " numbers" , new Properties )
68+ val df = sqlContext.read.jdbc(db.jdbcUrl , " numbers" , new Properties )
15769 val rows = df.collect()
15870 assert(rows.length == 1 )
15971 val types = rows(0 ).toSeq.map(x => x.getClass.toString)
@@ -180,7 +92,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
18092 }
18193
18294 test(" Date types" ) {
183- val df = sqlContext.read.jdbc(url(ip, " foo " ) , " dates" , new Properties )
95+ val df = sqlContext.read.jdbc(db.jdbcUrl , " dates" , new Properties )
18496 val rows = df.collect()
18597 assert(rows.length == 1 )
18698 val types = rows(0 ).toSeq.map(x => x.getClass.toString)
@@ -198,7 +110,7 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
198110 }
199111
200112 test(" String types" ) {
201- val df = sqlContext.read.jdbc(url(ip, " foo " ) , " strings" , new Properties )
113+ val df = sqlContext.read.jdbc(db.jdbcUrl , " strings" , new Properties )
202114 val rows = df.collect()
203115 assert(rows.length == 1 )
204116 val types = rows(0 ).toSeq.map(x => x.getClass.toString)
@@ -224,11 +136,11 @@ class MySQLIntegrationSuite extends SparkFunSuite with BeforeAndAfterAll with Sh
224136 }
225137
226138 test(" Basic write test" ) {
227- val df1 = sqlContext.read.jdbc(url(ip, " foo " ) , " numbers" , new Properties )
228- val df2 = sqlContext.read.jdbc(url(ip, " foo " ) , " dates" , new Properties )
229- val df3 = sqlContext.read.jdbc(url(ip, " foo " ) , " strings" , new Properties )
230- df1.write.jdbc(url(ip, " foo " ) , " numberscopy" , new Properties )
231- df2.write.jdbc(url(ip, " foo " ) , " datescopy" , new Properties )
232- df3.write.jdbc(url(ip, " foo " ) , " stringscopy" , new Properties )
139+ val df1 = sqlContext.read.jdbc(db.jdbcUrl , " numbers" , new Properties )
140+ val df2 = sqlContext.read.jdbc(db.jdbcUrl , " dates" , new Properties )
141+ val df3 = sqlContext.read.jdbc(db.jdbcUrl , " strings" , new Properties )
142+ df1.write.jdbc(db.jdbcUrl , " numberscopy" , new Properties )
143+ df2.write.jdbc(db.jdbcUrl , " datescopy" , new Properties )
144+ df3.write.jdbc(db.jdbcUrl , " stringscopy" , new Properties )
233145 }
234146}
0 commit comments