@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
+import java.math.BigDecimal
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.test.SharedSQLContext
@@ -93,8 +94,31 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
         |USING org.apache.spark.sql.jdbc
         |OPTIONS (url '$jdbcUrl', dbTable 'datetime1', oracle.jdbc.mapDateToTimestamp 'false')
       """.stripMargin.replaceAll("\n", " "))
+
+
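+    // Editor's note: fixture for the SPARK-16625 test below; the three DECIMAL columns
+    // exercise precision 1, a fractional scale, and a value larger than Int.MaxValue.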
+    conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate();
+    conn.prepareStatement(
+      "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate();
+    conn.commit();
   }
 
+
+  test("SPARK-16625 : Importing Oracle numeric types") {
+    val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
+    val rows = df.collect()
+    assert(rows.size == 1)
+    val row = rows(0)
+    // The main point of the below assertions is not to make sure that these Oracle types are
+    // mapped to decimal types, but to make sure that the returned values are correct.
+    // A value > 1 from DECIMAL(1) is correct:
+    assert(row.getDecimal(0).compareTo(BigDecimal.valueOf(4)) == 0)
+    // A value with fractions from DECIMAL(3, 2) is correct:
+    assert(row.getDecimal(1).compareTo(BigDecimal.valueOf(1.23)) == 0)
+    // A value > Int.MaxValue from DECIMAL(10) is correct:
+    assert(row.getDecimal(2).compareTo(BigDecimal.valueOf(9999999999L)) == 0)
+  }
+
+
   test("SPARK-12941: String datatypes to be mapped to Varchar in Oracle") {
     // create a sample dataframe with string type
     val df1 = sparkContext.parallelize(Seq(("foo"))).toDF("x")
@@ -154,27 +178,28 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
     val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
     val rows = dfRead.collect()
     // verify that the data types were mapped correctly
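+    // Editor's note: Oracle stores these JVM types as NUMBER columns (presumably via
+    // OracleDialect's mappings: boolean -> NUMBER(1), int -> NUMBER(10), long -> NUMBER(19),
+    // float/double -> NUMBER(19, 4), byte -> NUMBER(3), short -> NUMBER(5)), so every
+    // numeric field now reads back as a DecimalType rather than a primitive wrapper class.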
-    val types = rows(0).toSeq.map(x => x.getClass.toString)
-    assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
-    assert(types(2).equals("class java.lang.Long"))
-    assert(types(3).equals("class java.lang.Float"))
-    assert(types(4).equals("class java.lang.Float"))
-    assert(types(5).equals("class java.lang.Integer"))
-    assert(types(6).equals("class java.lang.Integer"))
-    assert(types(7).equals("class java.lang.String"))
-    assert(types(8).equals("class [B"))
-    assert(types(9).equals("class java.sql.Date"))
-    assert(types(10).equals("class java.sql.Timestamp"))
+    val types = dfRead.schema.map(field => field.dataType)
+    assert(types(0).equals(DecimalType(1, 0)))
+    assert(types(1).equals(DecimalType(10, 0)))
+    assert(types(2).equals(DecimalType(19, 0)))
+    assert(types(3).equals(DecimalType(19, 4)))
+    assert(types(4).equals(DecimalType(19, 4)))
+    assert(types(5).equals(DecimalType(3, 0)))
+    assert(types(6).equals(DecimalType(5, 0)))
+    assert(types(7).equals(StringType))
+    assert(types(8).equals(BinaryType))
+    assert(types(9).equals(DateType))
+    assert(types(10).equals(TimestampType))
+
     // verify that the inserted values are correct
     val values = rows(0)
-    assert(values.getBoolean(0).equals(booleanVal))
-    assert(values.getInt(1).equals(integerVal))
-    assert(values.getLong(2).equals(longVal))
-    assert(values.getFloat(3).equals(floatVal))
-    assert(values.getFloat(4).equals(doubleVal.toFloat))
-    assert(values.getInt(5).equals(byteVal.toInt))
-    assert(values.getInt(6).equals(shortVal.toInt))
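+    // Editor's note: compareTo is used instead of equals below because BigDecimal.equals
+    // also compares scale (1.0 and 1.00 are not equal), which these checks do not intend.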
+    assert(values.getDecimal(0).compareTo(BigDecimal.valueOf(1)) == 0)
+    assert(values.getDecimal(1).compareTo(BigDecimal.valueOf(integerVal)) == 0)
+    assert(values.getDecimal(2).compareTo(BigDecimal.valueOf(longVal)) == 0)
+    assert(values.getDecimal(3).compareTo(BigDecimal.valueOf(floatVal)) == 0)
+    assert(values.getDecimal(4).compareTo(BigDecimal.valueOf(doubleVal)) == 0)
+    assert(values.getDecimal(5).compareTo(BigDecimal.valueOf(byteVal)) == 0)
+    assert(values.getDecimal(6).compareTo(BigDecimal.valueOf(shortVal)) == 0)
     assert(values.getString(7).equals(stringVal))
     assert(values.getAs[Array[Byte]](8).mkString.equals("678"))
     assert(values.getDate(9).equals(dateVal))
@@ -183,7 +208,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
 
   test("SPARK-19318: connection property keys should be case-sensitive") {
     def checkRow(row: Row): Unit = {
-      assert(row.getInt(0) == 1)
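+      // Editor's note: the first column now arrives as an Oracle NUMBER and is read back
+      // as a decimal, hence getDecimal instead of getInt.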
+      assert(row.getDecimal(0).equals(BigDecimal.valueOf(1)))
       assert(row.getDate(1).equals(Date.valueOf("1991-11-09")))
       assert(row.getTimestamp(2).equals(Timestamp.valueOf("1996-01-01 01:23:45")))
     }