diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb27..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/LICENSE b/LICENSE index f9e412cade345..790476ece15bd 100644 --- a/LICENSE +++ b/LICENSE @@ -211,712 +211,45 @@ subcomponents is subject to the terms and conditions of the following licenses. -======================================================================= -For the Boto EC2 library (ec2/third_party/boto*.zip): -======================================================================= - -Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/ - -Permission is hereby granted, free of charge, to any person obtaining a -copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, dis- -tribute, sublicense, and/or sell copies of the Software, and to permit -persons to whom the Software is furnished to do so, subject to the fol- -lowing conditions: - -The above copyright notice and this permission notice shall be included -in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- -ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. - - -======================================================================== -For CloudPickle (pyspark/cloudpickle.py): -======================================================================== - -Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of the University of California, Berkeley nor the - names of its contributors may be used to endorse or promote - products derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For Py4J (python/lib/py4j-0.8.2.1-src.zip) -======================================================================== - -Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. - -- The name of the author may not be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For DPark join code (python/pyspark/join.py): -======================================================================== - -Copyright (c) 2011, Douban Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - - * Neither the name of the Douban Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ======================================================================== For heapq (pyspark/heapq3.py): ======================================================================== -# A. HISTORY OF THE SOFTWARE -# ========================== -# -# Python was created in the early 1990s by Guido van Rossum at Stichting -# Mathematisch Centrum (CWI, see http://www.cwi.nl) in the Netherlands -# as a successor of a language called ABC. Guido remains Python's -# principal author, although it includes many contributions from others. -# -# In 1995, Guido continued his work on Python at the Corporation for -# National Research Initiatives (CNRI, see http://www.cnri.reston.va.us) -# in Reston, Virginia where he released several versions of the -# software. -# -# In May 2000, Guido and the Python core development team moved to -# BeOpen.com to form the BeOpen PythonLabs team. In October of the same -# year, the PythonLabs team moved to Digital Creations (now Zope -# Corporation, see http://www.zope.com). In 2001, the Python Software -# Foundation (PSF, see http://www.python.org/psf/) was formed, a -# non-profit organization created specifically to own Python-related -# Intellectual Property. Zope Corporation is a sponsoring member of -# the PSF. -# -# All Python releases are Open Source (see http://www.opensource.org for -# the Open Source Definition). Historically, most, but not all, Python -# releases have also been GPL-compatible; the table below summarizes -# the various releases. -# -# Release Derived Year Owner GPL- -# from compatible? (1) -# -# 0.9.0 thru 1.2 1991-1995 CWI yes -# 1.3 thru 1.5.2 1.2 1995-1999 CNRI yes -# 1.6 1.5.2 2000 CNRI no -# 2.0 1.6 2000 BeOpen.com no -# 1.6.1 1.6 2001 CNRI yes (2) -# 2.1 2.0+1.6.1 2001 PSF no -# 2.0.1 2.0+1.6.1 2001 PSF yes -# 2.1.1 2.1+2.0.1 2001 PSF yes -# 2.2 2.1.1 2001 PSF yes -# 2.1.2 2.1.1 2002 PSF yes -# 2.1.3 2.1.2 2002 PSF yes -# 2.2.1 2.2 2002 PSF yes -# 2.2.2 2.2.1 2002 PSF yes -# 2.2.3 2.2.2 2003 PSF yes -# 2.3 2.2.2 2002-2003 PSF yes -# 2.3.1 2.3 2002-2003 PSF yes -# 2.3.2 2.3.1 2002-2003 PSF yes -# 2.3.3 2.3.2 2002-2003 PSF yes -# 2.3.4 2.3.3 2004 PSF yes -# 2.3.5 2.3.4 2005 PSF yes -# 2.4 2.3 2004 PSF yes -# 2.4.1 2.4 2005 PSF yes -# 2.4.2 2.4.1 2005 PSF yes -# 2.4.3 2.4.2 2006 PSF yes -# 2.4.4 2.4.3 2006 PSF yes -# 2.5 2.4 2006 PSF yes -# 2.5.1 2.5 2007 PSF yes -# 2.5.2 2.5.1 2008 PSF yes -# 2.5.3 2.5.2 2008 PSF yes -# 2.6 2.5 2008 PSF yes -# 2.6.1 2.6 2008 PSF yes -# 2.6.2 2.6.1 2009 PSF yes -# 2.6.3 2.6.2 2009 PSF yes -# 2.6.4 2.6.3 2009 PSF yes -# 2.6.5 2.6.4 2010 PSF yes -# 2.7 2.6 2010 PSF yes -# -# Footnotes: -# -# (1) GPL-compatible doesn't mean that we're distributing Python under -# the GPL. All Python licenses, unlike the GPL, let you distribute -# a modified version without making your changes open source. The -# GPL-compatible licenses make it possible to combine Python with -# other software that is released under the GPL; the others don't. 
-# -# (2) According to Richard Stallman, 1.6.1 is not GPL-compatible, -# because its license has a choice of law clause. According to -# CNRI, however, Stallman's lawyer has told CNRI's lawyer that 1.6.1 -# is "not incompatible" with the GPL. -# -# Thanks to the many outside volunteers who have worked under Guido's -# direction to make these releases possible. -# -# -# B. TERMS AND CONDITIONS FOR ACCESSING OR OTHERWISE USING PYTHON -# =============================================================== -# -# PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -# -------------------------------------------- -# -# 1. This LICENSE AGREEMENT is between the Python Software Foundation -# ("PSF"), and the Individual or Organization ("Licensee") accessing and -# otherwise using this software ("Python") in source or binary form and -# its associated documentation. -# -# 2. Subject to the terms and conditions of this License Agreement, PSF hereby -# grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, -# analyze, test, perform and/or display publicly, prepare derivative works, -# distribute, and otherwise use Python alone or in any derivative version, -# provided, however, that PSF's License Agreement and PSF's notice of copyright, -# i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, -# 2011, 2012, 2013 Python Software Foundation; All Rights Reserved" are retained -# in Python alone or in any derivative version prepared by Licensee. -# -# 3. In the event Licensee prepares a derivative work that is based on -# or incorporates Python or any part thereof, and wants to make -# the derivative work available to others as provided herein, then -# Licensee hereby agrees to include in any such work a brief summary of -# the changes made to Python. -# -# 4. PSF is making Python available to Licensee on an "AS IS" -# basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON -# FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS -# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, -# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 6. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 7. Nothing in this License Agreement shall be deemed to create any -# relationship of agency, partnership, or joint venture between PSF and -# Licensee. This License Agreement does not grant permission to use PSF -# trademarks or trade name in a trademark sense to endorse or promote -# products or services of Licensee, or any third party. -# -# 8. By copying, installing or otherwise using Python, Licensee -# agrees to be bound by the terms and conditions of this License -# Agreement. -# -# -# BEOPEN.COM LICENSE AGREEMENT FOR PYTHON 2.0 -# ------------------------------------------- -# -# BEOPEN PYTHON OPEN SOURCE LICENSE AGREEMENT VERSION 1 -# -# 1. This LICENSE AGREEMENT is between BeOpen.com ("BeOpen"), having an -# office at 160 Saratoga Avenue, Santa Clara, CA 95051, and the -# Individual or Organization ("Licensee") accessing and otherwise using -# this software in source or binary form and its associated -# documentation ("the Software"). -# -# 2. 
Subject to the terms and conditions of this BeOpen Python License -# Agreement, BeOpen hereby grants Licensee a non-exclusive, -# royalty-free, world-wide license to reproduce, analyze, test, perform -# and/or display publicly, prepare derivative works, distribute, and -# otherwise use the Software alone or in any derivative version, -# provided, however, that the BeOpen Python License is retained in the -# Software, alone or in any derivative version prepared by Licensee. -# -# 3. BeOpen is making the Software available to Licensee on an "AS IS" -# basis. BEOPEN MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, BEOPEN MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF THE SOFTWARE WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 4. BEOPEN SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF THE -# SOFTWARE FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS -# AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE, OR ANY -# DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 5. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 6. This License Agreement shall be governed by and interpreted in all -# respects by the law of the State of California, excluding conflict of -# law provisions. Nothing in this License Agreement shall be deemed to -# create any relationship of agency, partnership, or joint venture -# between BeOpen and Licensee. This License Agreement does not grant -# permission to use BeOpen trademarks or trade names in a trademark -# sense to endorse or promote products or services of Licensee, or any -# third party. As an exception, the "BeOpen Python" logos available at -# http://www.pythonlabs.com/logos.html may be used according to the -# permissions granted on that web page. -# -# 7. By copying, installing or otherwise using the software, Licensee -# agrees to be bound by the terms and conditions of this License -# Agreement. -# -# -# CNRI LICENSE AGREEMENT FOR PYTHON 1.6.1 -# --------------------------------------- -# -# 1. This LICENSE AGREEMENT is between the Corporation for National -# Research Initiatives, having an office at 1895 Preston White Drive, -# Reston, VA 20191 ("CNRI"), and the Individual or Organization -# ("Licensee") accessing and otherwise using Python 1.6.1 software in -# source or binary form and its associated documentation. -# -# 2. Subject to the terms and conditions of this License Agreement, CNRI -# hereby grants Licensee a nonexclusive, royalty-free, world-wide -# license to reproduce, analyze, test, perform and/or display publicly, -# prepare derivative works, distribute, and otherwise use Python 1.6.1 -# alone or in any derivative version, provided, however, that CNRI's -# License Agreement and CNRI's notice of copyright, i.e., "Copyright (c) -# 1995-2001 Corporation for National Research Initiatives; All Rights -# Reserved" are retained in Python 1.6.1 alone or in any derivative -# version prepared by Licensee. Alternately, in lieu of CNRI's License -# Agreement, Licensee may substitute the following text (omitting the -# quotes): "Python 1.6.1 is made available subject to the terms and -# conditions in CNRI's License Agreement. This Agreement together with -# Python 1.6.1 may be located on the Internet using the following -# unique, persistent identifier (known as a handle): 1895.22/1013. 
This -# Agreement may also be obtained from a proxy server on the Internet -# using the following URL: http://hdl.handle.net/1895.22/1013". -# -# 3. In the event Licensee prepares a derivative work that is based on -# or incorporates Python 1.6.1 or any part thereof, and wants to make -# the derivative work available to others as provided herein, then -# Licensee hereby agrees to include in any such work a brief summary of -# the changes made to Python 1.6.1. -# -# 4. CNRI is making Python 1.6.1 available to Licensee on an "AS IS" -# basis. CNRI MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR -# IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, CNRI MAKES NO AND -# DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS -# FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON 1.6.1 WILL NOT -# INFRINGE ANY THIRD PARTY RIGHTS. -# -# 5. CNRI SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON -# 1.6.1 FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS -# A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON 1.6.1, -# OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. -# -# 6. This License Agreement will automatically terminate upon a material -# breach of its terms and conditions. -# -# 7. This License Agreement shall be governed by the federal -# intellectual property law of the United States, including without -# limitation the federal copyright law, and, to the extent such -# U.S. federal law does not apply, by the law of the Commonwealth of -# Virginia, excluding Virginia's conflict of law provisions. -# Notwithstanding the foregoing, with regard to derivative works based -# on Python 1.6.1 that incorporate non-separable material that was -# previously distributed under the GNU General Public License (GPL), the -# law of the Commonwealth of Virginia shall govern this License -# Agreement only as to issues arising under or with respect to -# Paragraphs 4, 5, and 7 of this License Agreement. Nothing in this -# License Agreement shall be deemed to create any relationship of -# agency, partnership, or joint venture between CNRI and Licensee. This -# License Agreement does not grant permission to use CNRI trademarks or -# trade name in a trademark sense to endorse or promote products or -# services of Licensee, or any third party. -# -# 8. By clicking on the "ACCEPT" button where indicated, or by copying, -# installing or otherwise using Python 1.6.1, Licensee agrees to be -# bound by the terms and conditions of this License Agreement. -# -# ACCEPT -# -# -# CWI LICENSE AGREEMENT FOR PYTHON 0.9.0 THROUGH 1.2 -# -------------------------------------------------- -# -# Copyright (c) 1991 - 1995, Stichting Mathematisch Centrum Amsterdam, -# The Netherlands. All rights reserved. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose and without fee is hereby granted, -# provided that the above copyright notice appear in all copies and that -# both that copyright notice and this permission notice appear in -# supporting documentation, and that the name of Stichting Mathematisch -# Centrum or CWI not be used in advertising or publicity pertaining to -# distribution of the software without specific, written prior -# permission. 
-# -# STICHTING MATHEMATISCH CENTRUM DISCLAIMS ALL WARRANTIES WITH REGARD TO -# THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND -# FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE -# FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -======================================================================== -For sorttable (core/src/main/resources/org/apache/spark/ui/static/sorttable.js): -======================================================================== - -Copyright (c) 1997-2007 Stuart Langridge - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. - -======================================================================== -For d3 (core/src/main/resources/org/apache/spark/ui/static/d3.min.js): -======================================================================== - -Copyright (c) 2010-2015, Michael Bostock -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* The name Michael Bostock may not be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -======================================================================== -For Scala Interpreter classes (all .scala files in repl/src/main/scala -except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), -and for SerializableMapWrapper in JavaUtils.scala: -======================================================================== - -Copyright (c) 2002-2013 EPFL -Copyright (c) 2011-2013 Typesafe, Inc. - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -- Neither the name of the EPFL nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -======================================================================== -For sbt and sbt-launch-lib.bash in sbt/: -======================================================================== - -// Generated from http://www.opensource.org/licenses/bsd-license.php -Copyright (c) 2011, Paul Phillips. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of the author nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +See license/LICENSE-heapq.txt ======================================================================== For SnapTree: ======================================================================== -SNAPTREE LICENSE - -Copyright (c) 2009-2012 Stanford University, unless otherwise specified. -All rights reserved. - -This software was developed by the Pervasive Parallelism Laboratory of -Stanford University, California, USA. - -Permission to use, copy, modify, and distribute this software in source -or binary form for any purpose with or without fee is hereby granted, -provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - 3. Neither the name of Stanford University nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - -THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF -SUCH DAMAGE. - - -======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): -======================================================================== -Copyright (C) 2008 The Android Open Source Project - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
- -======================================================================== -For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java): -======================================================================== -Copyright (C) 2015 Stijn de Gouw - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -======================================================================== -For LimitedInputStream - (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): -======================================================================== -Copyright (C) 2007 The Guava Authors - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -======================================================================== -For vis.js (core/src/main/resources/org/apache/spark/ui/static/vis.min.js): -======================================================================== -Copyright (C) 2010-2015 Almende B.V. - -Vis.js is dual licensed under both - - * The Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0 - -and - - * The MIT License - http://opensource.org/licenses/MIT - -Vis.js may be distributed under either license. +See license/LICENSE-SnapTree.txt ======================================================================== -For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +For jbcrypt: ======================================================================== -Copyright (c) 2013 Chris Pettitt - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. 
-======================================================================== -For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): -======================================================================== -Copyright (c) 2012-2013 Chris Pettitt - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. +See license/LICENSE-jbcrypt.txt ======================================================================== BSD-style licenses ======================================================================== The following components are provided under a BSD-style license. See project link for details. +The text of each license is also included at licenses/LICENSE-[project].txt. - (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) + (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) - (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) - (BSD style) Hamcrest Core (org.hamcrest:hamcrest-core:1.1 - no url defined) + (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD) JLine (jline:jline:0.9.94 - http://jline.sourceforge.net) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.3 - http://paranamer.codehaus.org/paranamer) (BSD) ParaNamer Core (com.thoughtworks.paranamer:paranamer:2.6 - http://paranamer.codehaus.org/paranamer) - (BSD-like) (The BSD License) jline (org.scala-lang:jline:2.10.4 - http://www.scala-lang.org/) + (BSD 3 Clause) Scala (http://www.scala-lang.org/download/#License) + (Interpreter classes (all .scala files in repl/src/main/scala + except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), + and for SerializableMapWrapper in JavaUtils.scala) (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.10.4 - http://www.scala-lang.org/) (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.10.4 - http://www.scala-lang.org/) (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.10.4 - http://www.scala-lang.org/) @@ -932,15 +265,19 
@@ The following components are provided under a BSD-style license. See project lin (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.8.2.1 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.9 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) - (ISC/BSD License) jbcrypt (org.mindrot:jbcrypt:0.3m - http://www.mindrot.org/) + (BSD licence) sbt and sbt-launch-lib.bash + (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) + (BSD 3 Clause) DPark (https://github.com/douban/dpark/blob/master/LICENSE) + (BSD 3 Clause) CloudPickle (https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE) ======================================================================== MIT licenses ======================================================================== The following components are provided under the MIT License. See project link for details. +The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) JCL 1.1.1 implemented over SLF4J (org.slf4j:jcl-over-slf4j:1.7.5 - http://www.slf4j.org) (MIT License) JUL to SLF4J bridge (org.slf4j:jul-to-slf4j:1.7.5 - http://www.slf4j.org) @@ -951,3 +288,7 @@ The following components are provided under the MIT License. See project link fo (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) + (MIT License) graphlib-dot (https://github.com/cpettitt/graphlib-dot) + (MIT License) dagre-d3 (https://github.com/cpettitt/dagre-d3) + (MIT License) sorttable (https://github.com/stuartlangridge/sorttable) + (MIT License) boto (https://github.com/boto/boto/blob/develop/LICENSE) diff --git a/NOTICE b/NOTICE index 452aef2871652..7f7769f73047f 100644 --- a/NOTICE +++ b/NOTICE @@ -572,3 +572,38 @@ Copyright 2009-2013 The Apache Software Foundation Apache Avro IPC Copyright 2009-2013 The Apache Software Foundation + + +Vis.js +Copyright 2010-2015 Almende B.V. + +Vis.js is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0 + + and + + * The MIT License + http://opensource.org/licenses/MIT + +Vis.js may be distributed under either license. 
+ + +Vis.js uses and redistributes the following third-party libraries: + +- component-emitter + https://github.com/component/emitter + The MIT License + +- hammer.js + http://hammerjs.github.io/ + The MIT License + +- moment.js + http://momentjs.com/ + The MIT License + +- keycharm + https://github.com/AlexDM0/keycharm + The MIT License \ No newline at end of file diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d0d7201f004a2..3d6edb70ec98e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.5.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman @@ -33,4 +33,5 @@ Collate: 'mllib.R' 'serialize.R' 'sparkR.R' + 'stats.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9d39630706436..52f7a0106aae6 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,10 +23,13 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "attach", "cache", "collect", "columns", "count", + "cov", + "corr", "crosstab", "describe", "dim", @@ -38,6 +41,7 @@ exportMethods("arrange", "fillna", "filter", "first", + "freqItems", "group_by", "groupBy", "head", @@ -61,6 +65,7 @@ exportMethods("arrange", "repartition", "sample", "sample_frac", + "sampleBy", "saveAsParquetFile", "saveAsTable", "saveDF", @@ -104,6 +109,7 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "column", "concat", "concat_ws", "contains", @@ -224,7 +230,8 @@ exportMethods("agg") export("sparkRSQL.init", "sparkRHive.init") -export("cacheTable", +export("as.DataFrame", + "cacheTable", "clearCache", "createDataFrame", "createExternalTable", @@ -247,3 +254,5 @@ export("structField", "structType.jobj", "structType.structField", "print.structType") + +export("as.data.frame") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 8a00238b41d60..2acbd081cd504 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -271,7 +271,7 @@ setMethod("names<-", signature(x = "DataFrame"), function(x, value) { if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", listToSeq(as.list(value))) + sdf <- callJMethod(x@sdf, "toDF", as.list(value)) dataFrame(sdf) } }) @@ -843,10 +843,10 @@ setMethod("groupBy", function(x, ...) { cols <- list(...) if (length(cols) >= 1 && class(cols[[1]]) == "character") { - sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1])) + sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], cols[-1]) } else { jcol <- lapply(cols, function(c) { c@jc }) - sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol)) + sgd <- callJMethod(x@sdf, "groupBy", jcol) } groupedData(sgd) }) @@ -1075,12 +1075,20 @@ setMethod("subset", signature(x = "DataFrame"), #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) #' # Similar to R data frames columns can also be selected using `$` -#' df$age +#' df[,df$age] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) 
{ - sdf <- callJMethod(x@sdf, "select", col, toSeq(...)) - dataFrame(sdf) + if (length(col) > 1) { + if (length(list(...)) > 0) { + stop("To select multiple columns, use a character vector or list for col") + } + + select(x, as.list(col)) + } else { + sdf <- callJMethod(x@sdf, "select", col, list(...)) + dataFrame(sdf) + } }) #' @rdname select @@ -1090,7 +1098,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "select", listToSeq(jcols)) + sdf <- callJMethod(x@sdf, "select", jcols) dataFrame(sdf) }) @@ -1106,7 +1114,7 @@ setMethod("select", col(c)@jc } }) - sdf <- callJMethod(x@sdf, "select", listToSeq(cols)) + sdf <- callJMethod(x@sdf, "select", cols) dataFrame(sdf) }) @@ -1133,7 +1141,7 @@ setMethod("selectExpr", signature(x = "DataFrame", expr = "character"), function(x, expr, ...) { exprList <- list(expr, ...) - sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList)) + sdf <- callJMethod(x@sdf, "selectExpr", exprList) dataFrame(sdf) }) @@ -1290,8 +1298,10 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' Sort a DataFrame by the specified column(s). #' #' @param x A DataFrame to be sorted. -#' @param col Either a Column object or character vector indicating the field to sort on +#' @param col A character or Column object vector indicating the fields to sort on #' @param ... Additional sorting fields +#' @param decreasing A logical argument indicating sorting order for columns when +#' a character vector is specified for col #' @return A DataFrame where all elements are sorted. #' @rdname arrange #' @name arrange @@ -1304,23 +1314,52 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' arrange(df, df$col1) -#' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) +#' arrange(df, "col1", decreasing = TRUE) +#' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) #' } setMethod("arrange", - signature(x = "DataFrame", col = "characterOrColumn"), + signature(x = "DataFrame", col = "Column"), function(x, col, ...) { - if (class(col) == "character") { - sdf <- callJMethod(x@sdf, "sort", col, toSeq(...)) - } else if (class(col) == "Column") { jcols <- lapply(list(col, ...), function(c) { c@jc }) - sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols)) - } + + sdf <- callJMethod(x@sdf, "sort", jcols) dataFrame(sdf) }) +#' @rdname arrange +#' @export +setMethod("arrange", + signature(x = "DataFrame", col = "character"), + function(x, col, ..., decreasing = FALSE) { + + # all sorting columns + by <- list(col, ...) + + if (length(decreasing) == 1) { + # in case only 1 boolean argument - decreasing value is specified, + # it will be used for all columns + decreasing <- rep(decreasing, length(by)) + } else if (length(decreasing) != length(by)) { + stop("Arguments 'col' and 'decreasing' must have the same length") + } + + # builds a list of columns of type Column + # example: [[1]] Column Species ASC + # [[2]] Column Petal_Length DESC + jcols <- lapply(seq_len(length(decreasing)), function(i){ + if (decreasing[[i]]) { + desc(getColumn(x, by[[i]])) + } else { + asc(getColumn(x, by[[i]])) + } + }) + + do.call("arrange", c(x, jcols)) + }) + #' @rdname arrange #' @name orderby setMethod("orderBy", @@ -1375,9 +1414,10 @@ setMethod("where", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @param joinExpr (Optional) The expression used to perform the join. 
joinExpr must be a -#' Column expression. If joinExpr is omitted, join() wil perform a Cartesian join +#' Column expression. If joinExpr is omitted, join() will perform a Cartesian join #' @param joinType The type of join to perform. The following join types are available: -#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". +#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', +#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. #' @rdname join #' @name join @@ -1402,11 +1442,15 @@ setMethod("join", if (is.null(joinType)) { sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) } else { - if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) { + if (joinType %in% c("inner", "outer", "full", "fullouter", + "leftouter", "left_outer", "left", + "rightouter", "right_outer", "right", "leftsemi")) { + joinType <- gsub("_", "", joinType) sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) } else { stop("joinType must be one of the following types: ", - "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'") + "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left', + 'rightouter', 'right_outer', 'right', 'leftsemi'") } } } @@ -1528,18 +1572,17 @@ setMethod("except", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. \cr #' #' @param df A SparkSQL DataFrame #' @param path A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' #' @rdname write.df #' @name write.df @@ -1552,6 +1595,7 @@ setMethod("except", #' path <- "path/to/file.json" #' df <- jsonFile(sqlContext, path) #' write.df(df, "myfile", "parquet", "overwrite") +#' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } setMethod("write.df", signature(df = "DataFrame", path = "character"), @@ -1593,18 +1637,17 @@ setMethod("saveDF", #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: -#' append: Contents of this DataFrame are expected to be appended to existing data. -#' overwrite: Existing data is expected to be overwritten by the contents of -# this DataFrame. -#' error: An exception is expected to be thrown. +#' data already exists in the data source. 
There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr #' ignore: The save operation is expected to not save the contents of the DataFrame -# and to not change the existing data. +#' and to not change the existing data. \cr #' #' @param df A SparkSQL DataFrame #' @param tableName A name for the table #' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' #' @rdname saveAsTable #' @name saveAsTable @@ -1664,7 +1707,7 @@ setMethod("describe", signature(x = "DataFrame", col = "character"), function(x, col, ...) { colList <- list(col, ...) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1674,7 +1717,7 @@ setMethod("describe", signature(x = "DataFrame"), function(x) { colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) + sdf <- callJMethod(x@sdf, "describe", colList) dataFrame(sdf) }) @@ -1731,7 +1774,7 @@ setMethod("dropna", naFunctions <- callJMethod(x@sdf, "na") sdf <- callJMethod(naFunctions, "drop", - as.integer(minNonNulls), listToSeq(as.list(cols))) + as.integer(minNonNulls), as.list(cols)) dataFrame(sdf) }) @@ -1787,17 +1830,15 @@ setMethod("fillna", if (length(colNames) == 0 || !all(colNames != "")) { stop("value should be an a named list with each name being a column name.") } - - # Convert to the named list to an environment to be passed to JVM - valueMap <- new.env() - for (col in colNames) { - # Check each item in the named list is of valid type - v <- value[[col]] + # Check each item in the named list is of valid type + lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { stop("Each item in value should be an integer, numeric or charactor.") } - valueMap[[col]] <- v - } + }) + + # Convert to the named list to an environment to be passed to JVM + valueMap <- convertNamedListToEnv(value) # When value is a named list, caller is expected not to pass in cols if (!is.null(cols)) { @@ -1815,36 +1856,60 @@ setMethod("fillna", sdf <- if (length(cols) == 0) { callJMethod(naFunctions, "fill", value) } else { - callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + callJMethod(naFunctions, "fill", value, as.list(cols)) } dataFrame(sdf) }) -#' crosstab +#' This function downloads the contents of a DataFrame into an R's data.frame. +#' Since data.frames are held in memory, ensure that you have enough memory +#' in your system to accommodate the contents. #' -#' Computes a pair-wise frequency table of the given columns. Also known as a contingency -#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 -#' non-zero pair frequencies will be returned. +#' @title Download data from a DataFrame into a data.frame +#' @param x a DataFrame +#' @return a data.frame +#' @rdname as.data.frame +#' @examples \dontrun{ #' -#' @param col1 name of the first column. Distinct items will make the first item of each row. -#' @param col2 name of the second column. Distinct items will make the column names of the output. -#' @return a local R data.frame representing the contingency table. 
The first column of each row -#' will be the distinct values of `col1` and the column names will be the distinct values -#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have zero as their counts. +#' irisDF <- createDataFrame(sqlContext, iris) +#' df <- as.data.frame(irisDF[irisDF$Species == "setosa", ]) +#' } +setMethod("as.data.frame", + signature(x = "DataFrame"), + function(x, ...) { + # Check if additional parameters have been passed + if (length(list(...)) > 0) { + stop(paste("Unused argument(s): ", paste(list(...), collapse=", "))) + } + collect(x) + }) + +#' The specified DataFrame is attached to the R search path. This means that +#' the DataFrame is searched by R when evaluating a variable, so columns in +#' the DataFrame can be accessed by simply giving their names. #' -#' @rdname statfunctions -#' @name crosstab -#' @export +#' @rdname attach +#' @title Attach DataFrame to R search path +#' @param what (DataFrame) The DataFrame to attach +#' @param pos (integer) Specify position in search() where to attach. +#' @param name (character) Name to use for the attached DataFrame. Names +#' starting with package: are reserved for library. +#' @param warn.conflicts (logical) If TRUE, warnings are printed about conflicts +#' from attaching the database, unless that DataFrame contains an object #' @examples #' \dontrun{ -#' df <- jsonFile(sqlCtx, "/path/to/file.json") -#' ct = crosstab(df, "title", "gender") +#' attach(irisDf) +#' summary(Sepal_Width) #' } -setMethod("crosstab", - signature(x = "DataFrame", col1 = "character", col2 = "character"), - function(x, col1, col2) { - statFunctions <- callJMethod(x@sdf, "stat") - sct <- callJMethod(statFunctions, "crosstab", col1, col2) - collect(dataFrame(sct)) +#' @seealso \link{detach} +setMethod("attach", + signature(what = "DataFrame"), + function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { + cols <- columns(what) + stopifnot(length(cols) > 0) + newEnv <- new.env() + for (i in 1:length(cols)) { + assign(x = cols[i], value = what[, cols[i]], envir = newEnv) + } + attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bc6445311473..1bf025cce4376 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -32,6 +32,7 @@ infer_type <- function(x) { numeric = "double", raw = "binary", list = "array", + struct = "struct", environment = "map", Date = "date", POSIXlt = "timestamp", @@ -41,45 +42,45 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) + + paste0("array<", infer_type(x[[1]]), ">") + } else if (type == "struct") { + stopifnot(length(x) > 0) names <- names(x) - if (is.null(names)) { - list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE) - } else { - # StructType - types <- lapply(x, infer_type) - fields <- lapply(1:length(x), function(i) { - structField(names[[i]], types[[i]], TRUE) - }) - do.call(structType, fields) - } + stopifnot(!is.null(names)) + + type <- lapply(seq_along(x), function(i) { + paste0(names[[i]], ":", infer_type(x[[i]]), ",") + }) + type <- Reduce(paste0, type) + type <- paste0("struct<", substr(type, 1, nchar(type) - 1), ">") } else if (length(x) > 1) { - list(type = "array", elementType = type, 
containsNull = TRUE) + paste0("array<", infer_type(x[[1]]), ">") } else { type } } -#' Create a DataFrame from an RDD +#' Create a DataFrame #' -#' Converts an RDD to a DataFrame by infer the types. +#' Converts R data.frame or list into DataFrame. #' #' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame +#' @rdname createDataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlContext, rdd) +#' df1 <- as.DataFrame(sqlContext, iris) +#' df2 <- as.DataFrame(sqlContext, list(3,4,5,6)) +#' df3 <- createDataFrame(sqlContext, iris) #' } # TODO(davies): support sampling and infer type from NA @@ -152,6 +153,13 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 dataFrame(sdf) } +#' @rdname createDataFrame +#' @aliases createDataFrame +#' @export +as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { + createDataFrame(sqlContext, data, schema, samplingRatio) +} + # toDF # # Converts an RDD to a DataFrame by infer the types. @@ -444,14 +452,21 @@ dropTempTable <- function(sqlContext, tableName) { #' #' @param sqlContext SQLContext to use #' @param path The path of files to load -#' @param source the name of external data source +#' @param source The name of external data source +#' @param schema The data schema defined in structType #' @return DataFrame +#' @rdname read.df +#' @name read.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, "path/to/file.json", source = "json") +#' df1 <- read.df(sqlContext, "path/to/file.json", source = "json") +#' schema <- structType(structField("name", "string"), +#' structField("info", "map")) +#' df2 <- read.df(sqlContext, mapTypeJsonPath, "json", schema) +#' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") #' } read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { @@ -474,9 +489,8 @@ read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) dataFrame(sdf) } -#' @aliases loadDF -#' @export - +#' @rdname read.df +#' @name loadDF loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { read.df(sqlContext, path, source, schema, ...) 
} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 4805096f3f9c5..20de3907b7dd9 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -36,13 +36,11 @@ setMethod("initialize", "Column", function(.Object, jc) { .Object }) -column <- function(jc) { - new("Column", jc) -} - -col <- function(x) { - column(callJStatic("org.apache.spark.sql.functions", "col", x)) -} +setMethod("column", + signature(x = "jobj"), + function(x) { + new("Column", x) + }) #' @rdname show #' @name show @@ -211,8 +209,7 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - table <- listToSeq(as.list(table)) - jc <- callJMethod(x@jc, "in", table) + jc <- callJMethod(x@jc, "in", as.list(table)) return(column(jc)) }) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 6cf628e3007de..f7e56e43016ea 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,8 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), + "s" = readStruct(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -57,8 +59,10 @@ readTypedObject <- function(con, type) { readString <- function(con) { stringLen <- readInt(con) - string <- readBin(con, raw(), stringLen, endian = "big") - rawToChar(string) + raw <- readBin(con, raw(), stringLen, endian = "big") + string <- rawToChar(raw) + Encoding(string) <- "UTF-8" + string } readInt <- function(con) { @@ -119,6 +123,28 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + +# Read a field of StructType from DataFrame +# into a named list in R whose class is "struct" +readStruct <- function(con) { + names <- readObject(con) + fields <- readObject(con) + names(fields) <- names + listToStruct(fields) +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d848730e70433..a72fb7bb42fef 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -18,16 +18,21 @@ #' @include generics.R column.R NULL -#' Creates a \code{Column} of literal value. +#' lit #' -#' The passed in object is returned directly if it is already a \linkS4class{Column}. -#' If the object is a Scala Symbol, it is converted into a \linkS4class{Column} also. -#' Otherwise, a new \linkS4class{Column} is created to represent the literal value. +#' A new \linkS4class{Column} is created to represent the literal value. +#' If the parameter is a \linkS4class{Column}, it is returned unchanged. #' #' @family normal_funcs #' @rdname lit #' @name lit #' @export +#' @examples +#' \dontrun{ +#' lit(df$name) +#' select(df, lit("x")) +#' select(df, lit("2015-01-01")) +#'} setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -233,6 +238,28 @@ setMethod("ceil", column(jc) }) +#' Though scala functions has "col" function, we don't expose it in SparkR +#' because we don't want to conflict with the "col" function in the R base +#' package and we also have "column" function exported which is an alias of "col". +col <- function(x) { + column(callJStatic("org.apache.spark.sql.functions", "col", x)) +} + +#' column +#' +#' Returns a Column based on the given column name. 
+#' +#' @rdname col +#' @name column +#' @family normal_funcs +#' @export +#' @examples \dontrun{column(df)} +setMethod("column", + signature(x = "character"), + function(x) { + col(x) + }) + #' cos #' #' Computes the cosine of the given value. @@ -1331,7 +1358,7 @@ setMethod("countDistinct", x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc, - listToSeq(jcol)) + jcol) column(jc) }) @@ -1348,7 +1375,7 @@ setMethod("concat", signature(x = "Column"), function(x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "concat", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "concat", jcols) column(jc) }) @@ -1366,7 +1393,7 @@ setMethod("greatest", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "greatest", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "greatest", jcols) column(jc) }) @@ -1384,7 +1411,7 @@ setMethod("least", function(x, ...) { stopifnot(length(list(...)) > 0) jcols <- lapply(list(x, ...), function(x) { x@jc }) - jc <- callJStatic("org.apache.spark.sql.functions", "least", listToSeq(jcols)) + jc <- callJStatic("org.apache.spark.sql.functions", "least", jcols) column(jc) }) @@ -1675,7 +1702,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @export setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jcols <- lapply(list(x, ...), function(x) { x@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) column(jc) }) @@ -1723,7 +1750,7 @@ setMethod("expr", signature(x = "character"), #' @export setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { - jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jcols <- lapply(list(x, ...), function(arg) { arg@jc }) jc <- callJStatic("org.apache.spark.sql.functions", "format_string", format, jcols) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 43dd8d283ab6b..4a419f785e92c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -63,6 +63,10 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) # @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) +# @rdname statfunctions +# @export +setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) @@ -399,6 +403,14 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) +#' @rdname statfunctions +#' @export +setGeneric("cov", function(x, col1, col2) {standardGeneric("cov") }) + +#' @rdname statfunctions +#' @export +setGeneric("corr", function(x, col1, col2, method = "pearson") {standardGeneric("corr") }) + #' @rdname describe #' @export setGeneric("describe", function(x, col, ...) 
{ standardGeneric("describe") }) @@ -497,6 +509,10 @@ setGeneric("sample", setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) +#' @rdname statfunctions +#' @export +setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) + #' @rdname saveAsParquetFile #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) @@ -674,6 +690,10 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname col +#' @export +setGeneric("column", function(x) { standardGeneric("column") }) + #' @rdname concat #' @export setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -983,3 +1003,11 @@ setGeneric("glm") #' @rdname rbind #' @export setGeneric("rbind", signature = "...") + +#' @rdname as.data.frame +#' @export +setGeneric("as.data.frame") + +#' @rdname attach +#' @export +setGeneric("attach") diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 576ac72f40fc0..4cab1a69f601a 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -102,7 +102,7 @@ setMethod("agg", } } jcols <- lapply(cols, function(c) { c@jc }) - sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1])) + sdf <- callJMethod(x@sgd, "agg", jcols[[1]], jcols[-1]) } else { stop("agg can only support Column or character") } @@ -124,7 +124,7 @@ createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), function(x, ...) { - sdf <- callJMethod(x@sgd, name, toSeq(...)) + sdf <- callJMethod(x@sgd, name, list(...)) dataFrame(sdf) }) } diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index cea3d760d05fe..25615e805e03c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '+', '-', and '.'. +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter @@ -41,14 +41,16 @@ setClass("PipelineModel", representation(model = "jobj")) #' sqlContext <- sparkRSQL.init(sc) #' data(iris) #' df <- createDataFrame(sqlContext, iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian") +#' summary(model) #'} setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, + solver = "auto") { family <- match.arg(family) model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitRModelFormula", deparse(formula), data@sdf, family, lambda, - alpha) + alpha, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 79c744ef29c23..6f0e9a94e9bfa 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -56,7 +56,7 @@ structType.structField <- function(x, ...) 
{ }) stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructType", - listToSeq(sfObjList)) + sfObjList) structType(stObj) } @@ -114,6 +114,79 @@ structField.jobj <- function(x) { obj } +checkType <- function(type) { + primtiveTypes <- c("byte", + "integer", + "float", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + if (type %in% primtiveTypes) { + return() + } else { + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.+),(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }, + s = { + # Struct type + m <- regexec("^struct<(.+)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + fieldsString <- matchedStrings[[1]][2] + # strsplit does not return the final empty string, so check if + # the final char is "," + if (substr(fieldsString, nchar(fieldsString), nchar(fieldsString)) != ",") { + fields <- strsplit(fieldsString, ",")[[1]] + for (field in fields) { + m <- regexec("^(.+):(.+)$", field) + matchedStrings <- regmatches(field, m) + if (length(matchedStrings[[1]]) >= 3) { + fieldType <- matchedStrings[[1]][3] + checkType(fieldType) + } else { + break + } + } + return() + } + } + }) + } + + stop(paste("Unsupported type for Dataframe:", type)) +} + structField.character <- function(x, type, nullable = TRUE) { if (class(x) != "character") { stop("Field name must be a string.") @@ -124,28 +197,13 @@ structField.character <- function(x, type, nullable = TRUE) { if (class(nullable) != "logical") { stop("nullable must be either TRUE or FALSE") } - options <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - dataType <- if (type %in% options) { - type - } else { - stop(paste("Unsupported type for Dataframe:", type)) - } + + checkType(type) + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createStructField", x, - dataType, + type, nullable) structField(sfObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index e3676f57f907f..17082b4e52fcf 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -32,6 +32,21 @@ # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +getSerdeType <- function(object) { + type <- class(object)[[1]] + if (type != "list") { + type + } else { + # Check if all elements are of same type + elemType <- unique(sapply(object, function(elem) { getSerdeType(elem) })) + if (length(elemType) <= 1) { + "array" + } else { + "list" + } + } +} + writeObject <- function(con, object, writeType = TRUE) { # NOTE: In R vectors have same type as objects. 
So we don't support # passing in vectors as arrays and instead require arrays to be passed @@ -45,10 +60,12 @@ writeObject <- function(con, object, writeType = TRUE) { type <- "NULL" } } + + serdeType <- getSerdeType(object) if (writeType) { - writeType(con, type) + writeType(con, serdeType) } - switch(type, + switch(serdeType, NULL = writeVoid(con), integer = writeInt(con, object), character = writeString(con, object), @@ -56,7 +73,9 @@ writeObject <- function(con, object, writeType = TRUE) { double = writeDouble(con, object), numeric = writeDouble(con, object), raw = writeRaw(con, object), + array = writeArray(con, object), list = writeList(con, object), + struct = writeList(con, object), jobj = writeJobj(con, object), environment = writeEnv(con, object), Date = writeDate(con, object), @@ -79,7 +98,7 @@ writeJobj <- function(con, value) { writeString <- function(con, value) { utfVal <- enc2utf8(value) writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) - writeBin(utfVal, con, endian = "big") + writeBin(utfVal, con, endian = "big", useBytes=TRUE) } writeInt <- function(con, value) { @@ -110,7 +129,7 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeGenericList(rawObj, row) + writeList(rawObj, row) rawConnectionValue(rawObj) } @@ -128,7 +147,9 @@ writeType <- function(con, class) { double = "d", numeric = "d", raw = "r", + array = "a", list = "l", + struct = "s", jobj = "j", environment = "e", Date = "D", @@ -139,15 +160,13 @@ writeType <- function(con, class) { } # Used to pass arrays where all the elements are of the same type -writeList <- function(con, arr) { - # All elements should be of same type - elemType <- unique(sapply(arr, function(elem) { class(elem) })) - stopifnot(length(elemType) <= 1) - +writeArray <- function(con, arr) { # TODO: Empty lists are given type "character" right now. # This may not work if the Java side expects array of any other type. 
- if (length(elemType) == 0) { + if (length(arr) == 0) { elemType <- class("somestring") + } else { + elemType <- getSerdeType(arr[[1]]) } writeType(con, elemType) @@ -161,7 +180,7 @@ writeList <- function(con, arr) { } # Used to pass arrays where the elements can be of different types -writeGenericList <- function(con, list) { +writeList <- function(con, list) { writeInt(con, length(list)) for (elem in list) { writeObject(con, elem) @@ -174,9 +193,9 @@ writeEnv <- function(con, env) { writeInt(con, len) if (len > 0) { - writeList(con, as.list(ls(env))) + writeArray(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeGenericList(con, as.list(vals)) + writeList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 3c57a44db257d..043b0057bd04a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -39,6 +39,14 @@ sparkR.stop <- function() { sc <- get(".sparkRjsc", envir = env) callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) + + if (exists(".sparkRSQLsc", envir = env)) { + rm(".sparkRSQLsc", envir = env) + } + + if (exists(".sparkRHivesc", envir = env)) { + rm(".sparkRHivesc", envir = env) + } } if (exists(".backendLaunched", envir = env)) { @@ -163,22 +171,16 @@ sparkR.init <- function( sparkHome <- suppressWarnings(normalizePath(sparkHome)) } - sparkEnvirMap <- new.env() - for (varname in names(sparkEnvir)) { - sparkEnvirMap[[varname]] <- sparkEnvir[[varname]] - } + sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) - sparkExecutorEnvMap <- new.env() - if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) { + sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) + if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH")) } - for (varname in names(sparkExecutorEnv)) { - sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]] - } nonEmptyJars <- Filter(function(x) { x != "" }, jars) - localJarPaths <- sapply(nonEmptyJars, + localJarPaths <- lapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) }) # Set the start time to identify jobjs @@ -193,7 +195,7 @@ sparkR.init <- function( master, appName, as.character(sparkHome), - as.list(localJarPaths), + localJarPaths, sparkEnvirMap, sparkExecutorEnvMap), envir = .sparkREnv diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R new file mode 100644 index 0000000000000..f79329b115404 --- /dev/null +++ b/R/pkg/R/stats.R @@ -0,0 +1,161 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# stats.R - Statistic functions for DataFrames. + +setOldClass("jobj") + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. 
The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have zero as their counts. +#' +#' @rdname statfunctions +#' @name crosstab +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' ct <- crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) + +#' cov +#' +#' Calculate the sample covariance of two numerical columns of a DataFrame. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @return the covariance of the two columns. +#' +#' @rdname statfunctions +#' @name cov +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' cov <- cov(df, "title", "gender") +#' } +setMethod("cov", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "cov", col1, col2) + }) + +#' corr +#' +#' Calculates the correlation of two columns of a DataFrame. +#' Currently only supports the Pearson Correlation Coefficient. +#' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. +#' +#' @param x A SparkSQL DataFrame +#' @param col1 the name of the first column +#' @param col2 the name of the second column +#' @param method Optional. A character specifying the method for calculating the correlation. +#' only "pearson" is allowed now. +#' @return The Pearson Correlation Coefficient as a Double. +#' +#' @rdname statfunctions +#' @name corr +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' corr <- corr(df, "title", "gender") +#' corr <- corr(df, "title", "gender", method = "pearson") +#' } +setMethod("corr", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2, method = "pearson") { + statFunctions <- callJMethod(x@sdf, "stat") + callJMethod(statFunctions, "corr", col1, col2, method) + }) + +#' freqItems +#' +#' Finding frequent items for columns, possibly with false positives. +#' Using the frequent element count algorithm described in +#' \url{http://dx.doi.org/10.1145/762471.762473}, proposed by Karp, Schenker, and Papadimitriou. +#' +#' @param x A SparkSQL DataFrame. +#' @param cols A vector column names to search frequent items in. +#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' Should be greater than 1e-4. Default support = 0.01. 
+#' @return a local R data.frame with the frequent items in each column +#' +#' @rdname statfunctions +#' @name freqItems +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' fi = freqItems(df, c("title", "gender")) +#' } +setMethod("freqItems", signature(x = "DataFrame", cols = "character"), + function(x, cols, support = 0.01) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support) + collect(dataFrame(sct)) + }) + +#' sampleBy +#' +#' Returns a stratified sample without replacement based on the fraction given on each stratum. +#' +#' @param x A SparkSQL DataFrame +#' @param col column that defines strata +#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is +#' not specified, we treat its fraction as zero. +#' @param seed random seed +#' @return A new DataFrame that represents the stratified sample +#' +#' @rdname statfunctions +#' @name sampleBy +#' @export +#' @examples +#'\dontrun{ +#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' sample <- sampleBy(df, "key", fractions, 36) +#' } +setMethod("sampleBy", + signature(x = "DataFrame", col = "character", + fractions = "list", seed = "numeric"), + function(x, col, fractions, seed) { + fractionsEnv <- convertNamedListToEnv(fractions) + + statFunctions <- callJMethod(x@sdf, "stat") + # Seed is expected to be Long on Scala side, here convert it to an integer + # due to SerDe limitation now. + sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed)) + dataFrame(sdf) + }) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3babcb519378e..0b9e2957fe9a5 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -361,16 +361,6 @@ numToInt <- function(num) { as.integer(num) } -# create a Seq in JVM -toSeq <- function(...) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...)) -} - -# create a Seq in JVM from a list -listToSeq <- function(l) { - callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l) -} - # Utility function to recursively traverse the Abstract Syntax Tree (AST) of a # user defined function (UDF), and to examine variables in the UDF to decide # if their values should be included in the new function environment. 
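Reviewer note on the new sampleBy() in stats.R above: its @examples block leaves `fractions` undefined. A minimal, illustrative sketch follows (the data and the seed value 36 are hypothetical, not part of the patch); the named list gives one sampling fraction per stratum, strata not listed are treated as fraction zero, and the list is handed to the JVM via the new convertNamedListToEnv() helper.

# Mirrors the shape of the new sampleBy() test: 100 rows with string keys "0", "1", "2"
l <- lapply(0:99, function(i) { as.character(i %% 3) })
df <- createDataFrame(sqlContext, l, "key")
fractions <- list("0" = 0.1, "1" = 0.2)   # stratum "2" is omitted, so its fraction is zero
sampled <- sampleBy(df, "key", fractions, 36)
collect(orderBy(count(groupBy(sampled, "key")), "key"))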
@@ -598,3 +588,38 @@ mergePartitions <- function(rdd, zip) { PipelinedRDD(rdd, partitionFunc) } + +# Convert a named list to struct so that +# SerDe won't confuse between a normal named list and struct +listToStruct <- function(list) { + stopifnot(class(list) == "list") + stopifnot(!is.null(names(list))) + class(list) <- "struct" + list +} + +# Convert a struct to a named list +structToList <- function(struct) { + stopifnot(class(list) == "struct") + + class(struct) <- "list" + struct +} + +# Convert a named list to an environment to be passed to JVM +convertNamedListToEnv <- function(namedList) { + # Make sure each item in the list has a name + names <- names(namedList) + stopifnot( + if (is.null(names)) { + length(namedList) == 0 + } else { + !any(is.na(names)) + }) + + env <- new.env() + for (name in names) { + env[[name]] <- namedList[[name]] + } + env +} diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index 513bbc8e62059..e99815ed1562c 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -26,6 +26,16 @@ test_that("repeatedly starting and stopping SparkR", { } }) +test_that("repeatedly starting and stopping SparkR SQL", { + for (i in 1:4) { + sc <- sparkR.init() + sqlContext <- sparkRSQL.init(sc) + df <- createDataFrame(sqlContext, data.frame(a = 1:20)) + expect_equal(count(df), 20) + sparkR.stop() + } +}) + test_that("rdd GC across sparkR.stop", { sparkR.stop() sc <- sparkR.init() # sc should get id 0 diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index f272de78ad4a6..3331ce738358c 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -49,13 +49,21 @@ test_that("dot minus and intercept vs native glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) +test_that("feature interaction vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) coefs <- as.vector(stats$coefficients) rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) expect_true(all(abs(rCoefs - coefs) < 1e-6)) expect_true(all( as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 0da5e38654732..67d8b23cd7b8d 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -49,26 +49,29 @@ mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesNa, jsonPathNa) -test_that("infer types", { +# For test complex types in DataFrame +mockLinesComplexType <- + c("{\"c1\":[1, 2, 3], \"c2\":[\"a\", \"b\", \"c\"], \"c3\":[1.0, 2.0, 3.0]}", + "{\"c1\":[4, 5, 6], \"c2\":[\"d\", \"e\", \"f\"], \"c3\":[4.0, 5.0, 6.0]}", + "{\"c1\":[7, 8, 9], \"c2\":[\"g\", \"h\", \"i\"], \"c3\":[7.0, 
8.0, 9.0]}") +complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesComplexType, complexTypeJsonPath) + +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") expect_equal(infer_type(TRUE), "boolean") expect_equal(infer_type(as.Date("2015-03-11")), "date") expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp") - expect_equal(infer_type(c(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(1L, 2L)), - list(type = "array", elementType = "integer", containsNull = TRUE)) - testStruct <- infer_type(list(a = 1L, b = "2")) - expect_equal(class(testStruct), "structType") - checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) - checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) + expect_equal(infer_type(c(1L, 2L)), "array") + expect_equal(infer_type(list(1L, 2L)), "array") + expect_equal(infer_type(listToStruct(list(a = 1L, b = "2"))), "struct") e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -86,17 +89,28 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) + dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") expect_equal(count(df), 10) + expect_equal(count(dfAsDF), 10) expect_equal(nrow(df), 10) + expect_equal(nrow(dfAsDF), 10) expect_equal(ncol(df), 2) + expect_equal(ncol(dfAsDF), 2) expect_equal(dim(df), c(10, 2)) + expect_equal(dim(dfAsDF), c(10, 2)) expect_equal(columns(df), c("a", "b")) + expect_equal(columns(dfAsDF), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) + dfAsDF <- as.DataFrame(sqlContext, rdd) expect_is(df, "DataFrame") + expect_is(dfAsDF, "DataFrame") expect_equal(columns(df), c("_1", "_2")) + expect_equal(columns(dfAsDF), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) @@ -127,9 +141,13 @@ test_that("create DataFrame from RDD", { schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) df2 <- createDataFrame(sqlContext, df.toRDD, schema) + df2AsDF <- as.DataFrame(sqlContext, df.toRDD, schema) expect_equal(columns(df2), c("name", "age", "height")) + expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) + expect_equal(dtypes(df2AsDF), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(collect(where(df2, df2$name == "Bob")), c("Bob", 16, 176.5)) + expect_equal(collect(where(df2AsDF, df2$name == "Bob")), c("Bob", 16, 176.5)) localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18), @@ -236,18 +254,84 @@ test_that("create DataFrame with different data types", { 
expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -# TODO: enable this test after fix serialization for nested object -#test_that("create DataFrame with nested array and struct", { -# e <- new.env() -# assign("n", 3L, envir = e) -# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) -# expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), -# c("c", "map"), c("d", "struct"))) -# expect_equal(count(df), 1) -# ldf <- collect(df) -# expect_equal(ldf[1,], l[[1]]) -#}) +test_that("create DataFrame with complex types", { + e <- new.env() + assign("n", 3L, envir = e) + + s <- listToStruct(list(a = "aa", b = 3L)) + + l <- list(as.list(1:10), list("a", "b"), e, s) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"), + c("d", "struct"))) + expect_equal(count(df), 1) + ldf <- collect(df) + expect_equal(names(ldf), c("a", "b", "c", "d")) + expect_equal(ldf[1, 1][[1]], l[[1]]) + expect_equal(ldf[1, 2][[1]], l[[2]]) + + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) + + s <- ldf$d[[1]] + expect_equal(class(s), "struct") + expect_equal(s$a, "aa") + expect_equal(s$b, 3L) +}) + +# For test map type and struct type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + +test_that("Collect DataFrame with complex types", { + # ArrayType + df <- jsonFile(sqlContext, complexTypeJsonPath) + + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 3) + expect_equal(names(ldf), c("c1", "c2", "c3")) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # StructType + df <- jsonFile(sqlContext, mapTypeJsonPath) + expect_equal(dtypes(df), list(c("info", "struct"), + c("name", "string"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("info", "name")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "struct") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) +}) test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) @@ -431,6 +515,32 @@ test_that("collect() and take() on a DataFrame return the same number of rows an expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) +test_that("collect() support Unicode characters", { + markUtf8 <- function(s) { + 
Encoding(s) <- "UTF-8" + s + } + + lines <- c("{\"name\":\"안녕하세요\"}", + "{\"name\":\"您好\", \"age\":30}", + "{\"name\":\"こんにちは\", \"age\":19}", + "{\"name\":\"Xin chào\"}") + + jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPath) + + df <- read.df(sqlContext, jsonPath, "json") + rdf <- collect(df) + expect_true(is.data.frame(rdf)) + expect_equal(rdf$name[1], markUtf8("안녕하세요")) + expect_equal(rdf$name[2], markUtf8("您好")) + expect_equal(rdf$name[3], markUtf8("こんにちは")) + expect_equal(rdf$name[4], markUtf8("Xin chào")) + + df1 <- createDataFrame(sqlContext, rdf) + expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) +}) + test_that("multiple pipeline transformations result in an RDD with the correct values", { df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { @@ -585,6 +695,13 @@ test_that("select with column", { expect_equal(columns(df3), c("x")) expect_equal(count(df3), 3) expect_equal(collect(select(df3, "x"))[[1, 1]], "x") + + df4 <- select(df, c("name", "age")) + expect_equal(columns(df4), c("name", "age")) + expect_equal(count(df4), 3) + + expect_error(select(df, c("name", "age"), "name"), + "To select multiple columns, use a character vector or list for col") }) test_that("subsetting", { @@ -692,7 +809,7 @@ test_that("test HiveContext", { }) test_that("column operators", { - c <- SparkR:::col("a") + c <- column("a") c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) @@ -700,7 +817,7 @@ test_that("column operators", { }) test_that("column functions", { - c <- SparkR:::col("a") + c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) c3 <- cosh(c) + count(c) + crc32(c) + exp(c) @@ -894,7 +1011,7 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1,2], "Michael") - sorted2 <- arrange(df, "name") + sorted2 <- arrange(df, "name", decreasing = FALSE) expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) @@ -904,6 +1021,15 @@ test_that("arrange() and orderBy() on a DataFrame", { sorted4 <- orderBy(df, desc(df$name)) expect_equal(first(sorted4)$name, "Michael") expect_equal(collect(sorted4)[3,"name"], "Andy") + + sorted5 <- arrange(df, "age", "name", decreasing = TRUE) + expect_equal(collect(sorted5)[1,2], "Andy") + + sorted6 <- arrange(df, "age","name", decreasing = c(T, F)) + expect_equal(collect(sorted6)[1,2], "Andy") + + sorted7 <- arrange(df, "name", decreasing = FALSE) + expect_equal(collect(sorted7)[2,"age"], 19) }) test_that("filter() on a DataFrame", { @@ -945,7 +1071,7 @@ test_that("join() and merge() on a DataFrame", { expect_equal(names(joined2), c("age", "name", "name", "test")) expect_equal(count(joined2), 3) - joined3 <- join(df, df2, df$name == df2$name, "right_outer") + joined3 <- join(df, df2, df$name == df2$name, "rightouter") expect_equal(names(joined3), c("age", "name", "name", "test")) expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) @@ -956,11 +1082,34 @@ test_that("join() and merge() on a DataFrame", { expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) + joined5 <- join(df, df2, df$name == df2$name, "leftouter") + expect_equal(names(joined5), c("age", "name", "name", "test")) + expect_equal(count(joined5), 3) + 
expect_true(is.na(collect(orderBy(joined5, joined5$age))$age[1])) + + joined6 <- join(df, df2, df$name == df2$name, "inner") + expect_equal(names(joined6), c("age", "name", "name", "test")) + expect_equal(count(joined6), 3) + + joined7 <- join(df, df2, df$name == df2$name, "leftsemi") + expect_equal(names(joined7), c("age", "name")) + expect_equal(count(joined7), 3) + + joined8 <- join(df, df2, df$name == df2$name, "left_outer") + expect_equal(names(joined8), c("age", "name", "name", "test")) + expect_equal(count(joined8), 3) + expect_true(is.na(collect(orderBy(joined8, joined8$age))$age[1])) + + joined9 <- join(df, df2, df$name == df2$name, "right_outer") + expect_equal(names(joined9), c("age", "name", "name", "test")) + expect_equal(count(joined9), 4) + expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2])) + merged <- select(merge(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(merged), c("newAge", "name", "test")) expect_equal(count(merged), 4) - expect_equal(collect(orderBy(merged, joined4$name))$newAge[3], 24) + expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { @@ -1091,7 +1240,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") @@ -1234,9 +1383,79 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("cov() and corr() on a DataFrame", { + l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) + df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + result <- cov(df, "singles", "doubles") + expect_true(abs(result - 55.0 / 3) < 1e-12) + + result <- corr(df, "singles", "doubles") + expect_true(abs(result - 1.0) < 1e-12) + result <- corr(df, "singles", "doubles", "pearson") + expect_true(abs(result - 1.0) < 1e-12) +}) + +test_that("freqItems() on a DataFrame", { + input <- 1:1000 + rdf <- data.frame(numbers = input, letters = as.character(input), + negDoubles = input * -1.0, stringsAsFactors = F) + rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + df <- createDataFrame(sqlContext, rdf) + multiColResults <- freqItems(df, c("numbers", "letters"), support=0.1) + expect_true(1 %in% multiColResults$numbers[[1]]) + expect_true("1" %in% multiColResults$letters[[1]]) + singleColResult <- freqItems(df, "negDoubles", support=0.1) + expect_true(-1 %in% head(singleColResult$negDoubles)[[1]]) + + l <- lapply(c(0:99), function(i) { + if (i %% 2 == 0) { list(1L, -1.0) } + else { list(i, i * -1.0) }}) + df <- createDataFrame(sqlContext, l, c("a", "b")) + result <- freqItems(df, c("a", "b"), 0.4) + expect_identical(result[[1]], list(list(1L, 99L))) + expect_identical(result[[2]], list(list(-1, -99))) +}) + +test_that("sampleBy() on a DataFrame", { + l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) + df <- createDataFrame(sqlContext, l, "key") + fractions <- list("0" = 0.1, "1" = 0.2) + sample <- sampleBy(df, "key", fractions, 0) + result <- collect(orderBy(count(groupBy(sample, "key")), "key")) + expect_identical(as.list(result[1, ]), list(key = "0", count = 2)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 10)) +}) + 
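For anyone checking the expected constant in the cov()/corr() test above: 55.0/3 is simply the sample covariance of 0:9 against its doubles, which base R reproduces directly. The sketch below uses stats::cov and stats::cor explicitly so it unambiguously refers to base R rather than the new SparkR generics.

x <- 0:9
y <- 2 * x
# sum((x - mean(x)) * (y - mean(y))) / (n - 1) = 2 * 82.5 / 9 = 55/3
stats::cov(x, y)   # 18.3333..., i.e. 55/3
stats::cor(x, y)   # exactly 1, matching corr(df, "singles", "doubles")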
test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table Not Found: blah", retError), TRUE) + expect_equal(grepl("Table not found: blah", retError), TRUE) +}) + +test_that("Method as.data.frame as a synonym for collect()", { + irisDF <- createDataFrame(sqlContext, iris) + expect_equal(as.data.frame(irisDF), collect(irisDF)) + irisDF2 <- irisDF[irisDF$Species == "setosa", ] + expect_equal(as.data.frame(irisDF2), collect(irisDF2)) +}) + +test_that("attach() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_error(age) + attach(df) + expect_is(age, "DataFrame") + expected_age <- data.frame(age = c(NA, 30, 19)) + expect_equal(head(age), expected_age) + stat <- summary(age) + expect_equal(collect(stat)[5, "age"], "30") + age <- age$age + 1 + expect_is(age, "Column") + rm(age) + stat2 <- summary(age) + expect_equal(collect(stat2)[5, "age"], "30") + detach("df") + stat3 <- summary(df[, "age"]) + expect_equal(collect(stat3)[5, "age"], "30") + expect_error(age) }) unlink(parquetPath) diff --git a/README.md b/README.md index 380422ca00dbe..4116ef3563879 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Apache Spark Spark is a fast and general cluster computing system for Big Data. It provides -high-level APIs in Scala, Java, and Python, and an optimized engine that +high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, @@ -59,7 +59,7 @@ will run the Pi example locally. You can set the MASTER environment variable when running examples to submit examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +"yarn" to run on YARN, and "local" to run locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -94,5 +94,5 @@ distribution. ## Configuration -Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. 
diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc7..4b60ee00ffbe5 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a96..672e9469aec92 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -52,6 +52,10 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index ef0bb2ac13f08..8399033ac61ec 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -78,7 +79,7 @@ object Bagel extends Logging { val startTime = System.currentTimeMillis val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( + val combinedMsgs = msgs.combineByKeyWithClassTag( combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) val grouped = combinedMsgs.groupWith(verts) val superstep_ = superstep // Create a read-only copy of superstep for capture in closure @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae92096..0000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. 
-# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74b..0000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} diff --git a/bin/pyspark b/bin/pyspark index 8f2a3b5a7717b..18012ee4a0b4f 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -65,7 +65,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" -export 
PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" +export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.9-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 3c6169983e76b..a97d884f0bf39 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/build/mvn b/build/mvn index ec0380afad319..7603ea03deb73 100755 --- a/build/mvn +++ b/build/mvn @@ -104,8 +104,8 @@ install_scala() { "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" - SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar" - SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar" + SCALA_COMPILER="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-compiler.jar" + SCALA_LIBRARY="$(cd "$(dirname "${scala_bin}")/../lib" && pwd)/scala-library.jar" } # Setup healthy defaults for the Zinc port if none were provided from @@ -135,10 +135,10 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it -if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status -port ${ZINC_PORT}`" ]; then +if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} - ${ZINC_BIN} -shutdown -port ${ZINC_PORT} - ${ZINC_BIN} -start -port ${ZINC_PORT} \ + "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} + "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi diff --git a/build/sbt b/build/sbt index cc3203d79bccd..7d8d0993e57d8 100755 --- a/build/sbt +++ b/build/sbt @@ -20,10 +20,12 @@ # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so # that we can run Hive to generate the golden answer. This is not required for normal development # or testing. -for i in "$HIVE_HOME"/lib/* -do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" -done -export HADOOP_CLASSPATH +if [ -n "$HIVE_HOME" ]; then + for i in "$HIVE_HOME"/lib/* + do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" + done + export HADOOP_CLASSPATH +fi realpath () { ( diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9b..55cb094b4af46 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a35986..385b2e772d2c8 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403a..f3046be54d7c6 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f5..d6962e0da2f30 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20a..be42a638230b7 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e1363..19cba6e71ed19 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a7..771251f90ee36 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. @@ -19,10 +36,10 @@ # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of workers to start (Default: 2) -# - SPARK_EXECUTOR_CORES, Number of cores for the workers (Default: 1). -# - SPARK_EXECUTOR_MEMORY, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_DRIVER_MEMORY, Memory for Master (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) +# - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). +# - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) +# - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 
1000M, 2G) (Default: 1G) # - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) # - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: ‘default’) # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. diff --git a/core/pom.xml b/core/pom.xml index 4f79d71bf85fa..319a50049a82d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -272,6 +272,14 @@ org.apache.hadoop hadoop-client + + org.apache.curator + curator-client + + + org.apache.curator + curator-framework + org.apache.curator curator-recipes @@ -323,16 +331,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test @@ -341,7 +339,7 @@ net.razorvine pyrolite - 4.4 + 4.9 net.razorvine @@ -352,7 +350,11 @@ net.sf.py4j py4j - 0.8.2.1 + 0.9 + + + org.apache.spark + spark-test-tags_${scala.binary.version} diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0b8b604e18494..ee82d679935c0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,21 +21,30 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; +import javax.annotation.Nullable; +import scala.None$; +import scala.Option; import scala.Product2; import scala.Tuple2; import scala.collection.Iterator; +import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; +import 
org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -62,7 +71,7 @@ *
* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { +final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final BlockManager blockManager; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; private final Serializer serializer; + private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + @Nullable private MapStatus mapStatus; + private long[] partitionLengths; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; public BypassMergeSortShuffleWriter( - SparkConf conf, BlockManager blockManager, - Partitioner partitioner, - ShuffleWriteMetrics writeMetrics, - Serializer serializer) { + IndexShuffleBlockResolver shuffleBlockResolver, + BypassMergeSortShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); - this.numPartitions = partitioner.numPartitions(); this.blockManager = blockManager; - this.partitioner = partitioner; - this.writeMetrics = writeMetrics; - this.serializer = serializer; + final ShuffleDependency dep = handle.dependency(); + this.mapId = mapId; + this.shuffleId = dep.shuffleId(); + this.partitioner = dep.partitioner(); + this.numPartitions = partitioner.numPartitions(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.serializer = Serializer.getSerializer(dep.serializer()); + this.shuffleBlockResolver = shuffleBlockResolver; } @Override - public void insertAll(Iterator> records) throws IOException { + public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -124,13 +154,24 @@ public void insertAll(Iterator> records) throws IOException { for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } + + partitionLengths = + writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } - @Override - public long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException { + @VisibleForTesting + long[] getPartitionLengths() { + return partitionLengths; + } + + /** + * Concatenate all of the 
per-partition files into a single combined file. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). + */ + private long[] writePartitionedFile(File outputFile) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { @@ -151,7 +192,7 @@ public long[] writePartitionedFile( } finally { Closeables.close(in, copyThrewException); } - if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + if (!partitionWriters[i].fileSegment().file().delete()) { logger.error("Unable to delete file for partition {}", i); } } @@ -165,19 +206,33 @@ public long[] writePartitionedFile( } @Override - public void stop() throws IOException { - if (partitionWriters != null) { - try { - final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (DiskBlockObjectWriter writer : partitionWriters) { - // This method explicitly does _not_ throw exceptions: - writer.revertPartialWritesAndClose(); - if (!diskBlockManager.getFile(writer.blockId()).delete()) { - logger.error("Error while deleting file for block {}", writer.blockId()); + public Option stop(boolean success) { + if (stopping) { + return None$.empty(); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + if (partitionWriters != null) { + try { + for (DiskBlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + File file = writer.revertPartialWritesAndClose(); + if (!file.delete()) { + logger.error("Error while deleting file {}", file.getAbsolutePath()); + } + } + } finally { + partitionWriters = null; } } - } finally { - partitionWriters = null; + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return None$.empty(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java rename to core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index 4ee6a82c0423e..c11711966fa8c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java similarity index 94% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 3d1ef0c48adc5..85fdaa8115fa3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.File; @@ -48,7 +48,7 @@ *
* Incoming records are appended to data pages. When all records have been inserted (or when the * current thread's shuffle memory limit is reached), the in-memory records are sorted according to - * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then * written to a single output file (or multiple files, if we've spilled). The format of the output * files is the same as the format of the final output file written by * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are @@ -59,9 +59,9 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class UnsafeShuffleExternalSorter { +final class ShuffleExternalSorter { - private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; @@ -76,6 +76,10 @@ final class UnsafeShuffleExternalSorter { private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; + private long numRecordsInsertedSinceLastSpill = 0; + + /** Force this sorter to spill when there are this many elements in memory. For testing only */ + private final long numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -94,12 +98,12 @@ final class UnsafeShuffleExternalSorter { private long peakMemoryUsedBytes; // These variables are reset after spilling: - @Nullable private UnsafeShuffleInMemorySorter inMemSorter; + @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; - public UnsafeShuffleExternalSorter( + public ShuffleExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, @@ -117,11 +121,17 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.numElementsForSpillThreshold = + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); + + // preserve first page to ensure that we have at least one page to work with. Otherwise, + // other operators in the same task may starve this sorter (SPARK-9709). 
+ acquireNewPageIfNecessary(pageSizeBytes); } /** @@ -136,7 +146,8 @@ private void initializeForWriting() throws IOException { throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } - this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(initialSize); + numRecordsInsertedSinceLastSpill = 0; } /** @@ -162,7 +173,7 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } // This call performs the actual sort. - final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this @@ -402,6 +413,10 @@ public void insertRecord( int lengthInBytes, int partitionId) throws IOException { + if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + spill(); + } + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. final int totalSpaceRequired = lengthInBytes + 4; @@ -449,6 +464,7 @@ public void insertRecord( recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, partitionId); + numRecordsInsertedSinceLastSpill += 1; } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java similarity index 88% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 5bab501da9364..a8dee6c6101c1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.util.Comparator; import org.apache.spark.util.collection.Sorter; -final class UnsafeShuffleInMemorySorter { +final class ShuffleInMemorySorter { private final Sorter sorter; private static final class SortComparator implements Comparator { @@ -44,10 +44,10 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pointerArrayInsertPosition = 0; - public UnsafeShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); this.pointerArray = new long[initialSize]; - this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + this.sorter = new Sorter(ShuffleSortDataFormat.INSTANCE); } public void expandPointerArray() { @@ -92,14 +92,14 @@ public void insertRecord(long recordPointer, int partitionId) { /** * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. 
*/ - public static final class UnsafeShuffleSorterIterator { + public static final class ShuffleSorterIterator { private final long[] pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, long[] pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -117,8 +117,8 @@ public void loadNext() { /** * Return an iterator over record pointers in sorted order. */ - public UnsafeShuffleSorterIterator getSortedIterator() { + public ShuffleSorterIterator getSortedIterator() { sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java similarity index 86% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java rename to core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index a66d74ee44782..8a1e5aec6ff0e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import org.apache.spark.util.collection.SortDataFormat; -final class UnsafeShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { - public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); - private UnsafeShuffleSortDataFormat() { } + private ShuffleSortDataFormat() { } @Override public PackedRecordPointer getKey(long[] data, int pos) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java deleted file mode 100644 index 656ea0401a144..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.sort; - -import java.io.File; -import java.io.IOException; - -import scala.Product2; -import scala.collection.Iterator; - -import org.apache.spark.annotation.Private; -import org.apache.spark.TaskContext; -import org.apache.spark.storage.BlockId; - -/** - * Interface for objects that {@link SortShuffleWriter} uses to write its output files. - */ -@Private -public interface SortShuffleFileWriter { - - void insertAll(Iterator> records) throws IOException; - - /** - * Write all the data added into this shuffle sorter into a file in the disk store. This is - * called by the SortShuffleWriter and can go through an efficient path of just concatenating - * binary files if we decided to avoid merge-sorting. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) - */ - long[] writePartitionedFile( - BlockId blockId, - TaskContext context, - File outputFile) throws IOException; - - void stop() throws IOException; -} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java similarity index 90% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java rename to core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java index 7bac0dc0bbeb6..df9f7b7abe028 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import java.io.File; import org.apache.spark.storage.TempShuffleBlockId; /** - * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + * Metadata for a block of data written by {@link ShuffleExternalSorter}. */ final class SpillInfo { final long[] partitionLengths; diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java similarity index 98% rename from core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java rename to core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index fdb309e365f69..e8f050cb2dab1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.shuffle.unsafe; +package org.apache.spark.shuffle.sort; import javax.annotation.Nullable; import java.io.*; @@ -80,7 +80,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; @Nullable private MapStatus mapStatus; - @Nullable private UnsafeShuffleExternalSorter sorter; + @Nullable private ShuffleExternalSorter sorter; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. 
*/ @@ -104,15 +104,15 @@ public UnsafeShuffleWriter( IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, - UnsafeShuffleHandle handle, + SerializedShuffleHandle handle, int mapId, TaskContext taskContext, SparkConf sparkConf) throws IOException { final int numPartitions = handle.dependency().partitioner().numPartitions(); - if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( "UnsafeShuffleWriter can only be used for shuffles with at most " + - UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE() + " reduce partitions"); } this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; @@ -195,7 +195,7 @@ public void write(scala.collection.Iterator> records) throws IOEx private void open() throws IOException { assert (sorter == null); - sorter = new UnsafeShuffleExternalSorter( + sorter = new ShuffleExternalSorter( memoryManager, shuffleMemoryManager, blockManager, diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java index a90cc0e761f62..40b5fb7fe4b49 100644 --- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -15,6 +15,24 @@ * limitations under the License. */ +/* + * Based on TimSort.java from the Android Open Source Project + * + * Copyright (C) 2008 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection; import java.util.Comparator; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 71b76d5ddfaa7..d2bf297c6c178 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.util.Utils; @@ -62,21 +63,7 @@ public int compare(long aPrefix, long bPrefix) { } public static long computePrefix(byte[] bytes) { - if (bytes == null) { - return 0L; - } else { - /** - * TODO: If a wrapper for BinaryType is created (SPARK-8786), - * these codes below will be in the wrapper class. 
- */ - final int minLen = Math.min(bytes.length, 8); - long p = 0; - for (int i = 0; i < minLen; ++i) { - p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) - << (56 - 8 * i); - } - return p; - } + return ByteArray.getPrefix(bytes); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b1..0a311d2d935ac 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -160,15 +160,14 @@ public BoxedUnit apply() { * Allocates new sort data structures. Called when creating the sorter and after each spill. */ private void initializeForWriting() throws IOException { + // Note: Do not track memory for the pointer array for now because of SPARK-10474. + // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to + // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably + // fails if all other memory is already occupied. It should be safe to not track the array + // because its memory footprint is frequently much smaller than that of a page. This is a + // temporary hack that we should address in 1.6.0. + // TODO: track the pointer array memory! this.writeMetrics = new ShuffleWriteMetrics(); - final long pointerArrayMemory = - UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory); - if (memoryAcquired != pointerArrayMemory) { - shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory"); - } - this.inMemSorter = new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); this.isInMemSorterExternal = false; @@ -265,14 +264,7 @@ private long freeMemory() { shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } - if (inMemSorter != null) { - if (!isInMemSorterExternal) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - memoryFreed += sorterMemoryUsage; - shuffleMemoryManager.release(sorterMemoryUsage); - } - inMemSorter = null; - } + // TODO: track in-memory sorter memory usage (SPARK-10474) allocatedPages.clear(); currentPage = null; currentPagePosition = -1; @@ -310,17 +302,8 @@ public void cleanupResources() { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); - final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); - if (memoryAcquired < memoryToGrowPointerArray) { - shuffleMemoryManager.release(memoryAcquired); - spill(); - } else { - inMemSorter.expandPointerArray(); - shuffleMemoryManager.release(oldPointerArrayMemoryUsage); - } + // TODO: track the pointer array memory! 
(SPARK-10474) + inMemSorter.expandPointerArray(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 4989b05d63e23..501dfe77d13cb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,12 +24,15 @@ import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). */ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); private final File file; private InputStream in; @@ -73,7 +76,9 @@ public void loadNext() throws IOException { numRecordsRemaining--; if (numRecordsRemaining == 0) { in.close(); - file.delete(); + if (!file.delete() && file.exists()) { + logger.warn("Unable to delete spill file {}", file.getPath()); + } in = null; din = null; } diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8db..c85abc35b93bf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932b..d44cc85dcbd82 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index c39c8667d013e..5592b75afb75b 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils * @tparam T partial data that can be added in */ class Accumulable[R, T] private[spark] ( - @transient initialValue: R, + initialValue: R, param: AccumulableParam[R, T], val name: Option[String], internal: Boolean) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e51..7196e57d5d2e2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. 
- private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. 
*/ diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fc8cdde9348ee..9aafc9eb1cde7 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,6 +17,8 @@ package org.apache.spark +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -65,8 +67,8 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine) */ @DeveloperApi -class ShuffleDependency[K, V, C]( - @transient _rdd: RDD[_ <: Product2[K, V]], +class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( + @transient private val _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, @@ -76,6 +78,13 @@ class ShuffleDependency[K, V, C]( override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]] + private[spark] val keyClassName: String = reflect.classTag[K].runtimeClass.getName + private[spark] val valueClassName: String = reflect.classTag[V].runtimeClass.getName + // Note: It's possible that the combiner class tag is null, if the combineByKey + // methods in PairRDDFunctions are used instead of combineByKeyWithClassTag. + private[spark] val combinerClassName: Option[String] = + Option(reflect.classTag[C]).map(_.runtimeClass.getName) + val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ee60d697d8799..1f1f0b75de5f1 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable +import scala.concurrent.Future import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} @@ -147,11 +148,31 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) } } + /** + * Send ExecutorRegistered to the event loop to add a new executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. + */ + def addExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRegistered(executorId))) + } + /** * If the heartbeat receiver is not stopped, notify it of executor registrations. */ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + addExecutor(executorAdded.executorId) + } + + /** + * Send ExecutorRemoved to the event loop to remove a executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. 
+ */ + def removeExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRemoved(executorId))) } /** @@ -165,7 +186,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) * and expire it with loud error messages. */ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + removeExecutor(executorRemoved.executorId) } private def expireDeadHosts(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index f0598816d6c07..69f6e06ee0057 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -21,11 +21,10 @@ import org.apache.log4j.{LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Private import org.apache.spark.util.Utils /** - * :: DeveloperApi :: * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. @@ -33,7 +32,7 @@ import org.apache.spark.util.Utils * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility. * This will likely be changed or removed in future releases. */ -@DeveloperApi +@Private trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 0000000000000..f8a6f1d0d8cbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. 
+ * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a387592783850..72355cdfa68b3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -44,10 +45,10 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = context.sender.address.hostPort + val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.size + val serializedSize = mapOutputStatuses.length if (serializedSize > maxAkkaFrameSize) { val msg = s"Map output statuses were $serializedSize bytes which " + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." @@ -132,13 +133,57 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") - val startTime = System.currentTimeMillis + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + } + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range). + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + } + + /** + * Return statistics about all of the outputs for a given shuffle. 
+ */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -160,7 +205,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -175,22 +220,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } - logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } @@ -235,6 +276,21 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch + /** Whether to compute locality preferences for reduce tasks */ + private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + + // Number of map and reduce tasks above which we do not assign preferred locations based on map + // output sizes. We limit the size of jobs for which assign preferred locations as computing the + // top locations by size becomes expensive. + private val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. Making this larger will focus on fewer locations where most data + // can be read locally, but may lead to more delay in scheduling if those locations are busy. 
+ private val REDUCER_PREF_LOCS_FRACTION = 0.2 + /** * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver, * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). @@ -295,6 +351,30 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return the preferred hosts on which to run the given map output partition in a given shuffle, + * i.e. the nodes that the most outputs for that partition are on. + * + * @param dep shuffle dependency object + * @param partitionId map output partition that we want to read + * @return a sequence of host names + */ + def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int) + : Seq[String] = { + if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD && + dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) { + val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId, + dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION) + if (blockManagerIds.nonEmpty) { + blockManagerIds.get.map(_.host) + } else { + Nil + } + } else { + Nil + } + } + /** * Return a list of locations that each have fraction of map output greater than the specified * threshold. @@ -433,23 +513,25 @@ private[spark] object MapOutputTracker extends Logging { } /** - * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block - * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that - * block manager. + * Given an array of map statuses and a range of map output partitions, returns a sequence that, + * for each block manager ID, lists the shuffle block IDs and corresponding shuffle block sizes + * stored at that block manager. * * If any of the statuses is null (indicating a missing location due to a failed mapper), * throws a FetchFailedException. * * @param shuffleId Identifier for the shuffle - * @param reduceId Identifier for the reduce task + * @param startPartition Start of map output partition ID range (included in range) + * @param endPartition End of map output partition ID range (excluded from range) * @param statuses List of map statuses, indexed by map ID. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. 
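The locality preference introduced above only applies when both sides of the shuffle are below the 1000-task thresholds, and a host is only proposed if it already holds at least REDUCER_PREF_LOCS_FRACTION (0.2) of a reduce partition's input. A rough sketch of that filter, assuming a precomputed map from host to bytes of map output destined for one reduce partition (an illustration, not the tracker's actual getLocationsWithLargestOutputs implementation):

  // Illustrative: keep only hosts holding at least `fraction` of the map output
  // that a single reduce partition will read.
  def preferredHosts(bytesByHost: Map[String, Long], fraction: Double = 0.2): Seq[String] = {
    val total = bytesByHost.values.sum.toDouble
    if (total == 0) Nil
    else bytesByHost.collect { case (host, bytes) if bytes / total >= fraction => host }.toSeq
  }

Raising the fraction concentrates reads on fewer hosts but makes it more likely that no host qualifies, in which case the task simply gets no locality preference.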
*/ private def convertMapStatuses( shuffleId: Int, - reduceId: Int, + startPartition: Int, + endPartition: Int, statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] @@ -457,10 +539,12 @@ private[spark] object MapOutputTracker extends Logging { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) - throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage) } else { - splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + for (part <- startPartition until endPartition) { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part))) + } } } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 29e581bb57cbc..e4df7af81a6d2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -104,8 +104,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( - @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K, V]], + partitions: Int, + rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b344b5e173d67..58d3b846fd80d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -418,16 +418,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } // Validate memory fractions - val memoryKeys = Seq( + val deprecatedMemoryKeys = Seq( "spark.storage.memoryFraction", "spark.shuffle.memoryFraction", "spark.shuffle.safetyFraction", "spark.storage.unrollFraction", "spark.storage.safetyFraction") + val memoryKeys = Seq( + "spark.memory.fraction", + "spark.memory.storageFraction") ++ + deprecatedMemoryKeys for (key <- memoryKeys) { val value = getDouble(key, 0.5) if (value > 1 || value < 0) { - throw new IllegalArgumentException("$key should be between 0 and 1 (was '$value').") + throw new IllegalArgumentException(s"$key should be between 0 and 1 (was '$value').") + } + } + + // Warn against deprecated memory fractions (unless legacy memory management mode is enabled) + val legacyMemoryManagementKey = "spark.memory.useLegacyMode" + val legacyMemoryManagement = getBoolean(legacyMemoryManagementKey, false) + if (!legacyMemoryManagement) { + val keyset = deprecatedMemoryKeys.toSet + val detected = settings.keys().asScala.filter(keyset.contains) + if (detected.nonEmpty) { + logWarning("Detected deprecated memory fraction settings: " + + detected.mkString("[", ", ", "]") + ". As of Spark 1.6, execution and storage " + + "memory management are unified. All memory fractions used in the old model are " + + "now deprecated and no longer read. 
If you wish to use the old memory management, " + + s"you may explicitly enable `$legacyMemoryManagementKey` (not recommended).") } } @@ -576,7 +595,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.lookupTimeout" -> Seq( AlternateConfig("spark.akka.lookupTimeout", "1.4")), "spark.streaming.fileStream.minRememberDuration" -> Seq( - AlternateConfig("spark.streaming.minRememberDuration", "1.5")) + AlternateConfig("spark.streaming.minRememberDuration", "1.5")), + "spark.yarn.max.executor.failures" -> Seq( + AlternateConfig("spark.yarn.max.worker.failures", "1.5")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 738887076b0d1..a6857b4c7d882 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -89,14 +90,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // NOTE: this must be placed at the beginning of the SparkContext constructor. SparkContext.markPartiallyConstructed(this, allowMultipleContexts) - // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, - // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It - // contains a map from hostname to a list of input format splits on the host. - private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() - val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { @@ -115,16 +111,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Alternative constructor for setting preferred locations where Spark will create executors. * * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. - * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. + * @param preferredNodeLocationData not used. Left for backward compatibility. */ @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") @DeveloperApi def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { this(config) logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - this.preferredNodeLocationData = preferredNodeLocationData } /** @@ -146,10 +139,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. 
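The SparkConf validation above now checks the unified settings (spark.memory.fraction, spark.memory.storageFraction) alongside the deprecated per-pool fractions, and warns when the old keys are set without enabling legacy mode. A short configuration sketch using the key names from this patch (the numeric values are examples only):

  import org.apache.spark.SparkConf

  // Unified memory management (the default after this change): one fraction shared by
  // execution and storage, plus the storage share within it.
  val unified = new SparkConf()
    .set("spark.memory.fraction", "0.75")
    .set("spark.memory.storageFraction", "0.5")

  // Opting back into the legacy model: the old fractions are only honored when
  // spark.memory.useLegacyMode is explicitly enabled (not recommended).
  val legacy = new SparkConf()
    .set("spark.memory.useLegacyMode", "true")
    .set("spark.storage.memoryFraction", "0.6")
    .set("spark.shuffle.memoryFraction", "0.2")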
- * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. + * @param preferredNodeLocationData not used. Left for backward compatibility. */ + @deprecated("Passing in preferred locations has no effect at all, see SPARK-10921", "1.6.0") def this( master: String, appName: String, @@ -162,7 +154,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (preferredNodeLocationData.nonEmpty) { logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") } - this.preferredNodeLocationData = preferredNodeLocationData } // NOTE: The below constructors could be consolidated using default arguments. Due to @@ -176,7 +167,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param appName A name for your application, to display on the cluster web UI. */ private[spark] def this(master: String, appName: String) = - this(master, appName, null, Nil, Map(), Map()) + this(master, appName, null, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -186,7 +177,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param sparkHome Location where Spark is installed on cluster nodes. */ private[spark] def this(master: String, appName: String, sparkHome: String) = - this(master, appName, sparkHome, Nil, Map(), Map()) + this(master, appName, sparkHome, Nil, Map()) /** * Alternative constructor that allows setting common Spark properties directly @@ -198,7 +189,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * system or HDFS, HTTP, HTTPS, or FTP URLs. */ private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) = - this(master, appName, sparkHome, jars, Map(), Map()) + this(master, appName, sparkHome, jars, Map()) // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") @@ -265,6 +256,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. 
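isStopped exposes the same stopped flag that assertNotStopped() consults, so callers can skip late work instead of catching the resulting IllegalStateException. A minimal usage sketch (the method and its purpose are illustrative):

  import org.apache.spark.SparkContext

  // Illustrative: skip late work once the context has started shutting down.
  def maybeRunCleanupJob(sc: SparkContext): Unit = {
    if (!sc.isStopped) {
      sc.parallelize(1 to 10).count()
    }
  }

The check is only advisory: stop() can still race with the guarded block, so code that must never run after shutdown still needs to handle the resulting exception.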
+ */ + def isStopped: Boolean = stopped.get() + // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus @@ -273,7 +269,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus): SparkEnv = { - SparkEnv.createDriverEnv(conf, isLocal, listenerBus) + SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master)) } private[spark] def env: SparkEnv = _env @@ -347,8 +343,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } @@ -516,6 +516,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. @@ -858,7 +859,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -910,7 +911,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1092,7 +1093,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1516,8 +1517,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ @DeveloperApi def getRDDStorageInfo: Array[RDDInfo] = { + getRDDStorageInfo(_ => true) + } + + private[spark] def getRDDStorageInfo(filter: RDD[_] => Boolean): Array[RDDInfo] = { assertNotStopped() - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + val rddInfos = persistentRdds.values.filter(filter).map(RDDInfo.fromRdd).toArray StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) rddInfos.filter(_.isCached) } @@ -1741,6 +1746,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } SparkEnv.set(null) } + // Unset YARN mode system env variable, to allow switching between cluster types. + System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } @@ -1980,6 +1987,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. @@ -2536,6 +2560,21 @@ object SparkContext extends Logging { res } + /** + * The number of driver cores to use for execution in local mode, 0 otherwise. + */ + private[spark] def numDriverCores(master: String): Int = { + def convertToInt(threads: String): Int = { + if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt + } + master match { + case "local" => 1 + case SparkMasterRegex.LOCAL_N_REGEX(threads) => convertToInt(threads) + case SparkMasterRegex.LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads) + case _ => 0 // driver is not used for execution + } + } + /** * Create a task scheduler based on a given master URL. * Return a 2-tuple of the scheduler backend and the task scheduler. 
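numDriverCores feeds the new numCores argument of createDriverEnv: it reports a positive count only for local masters, because that is the only mode in which the driver itself executes tasks. Expected behavior under the regexes this patch factors out (a sketch, not a test taken from the patch; numDriverCores is private[spark], so this only compiles inside the org.apache.spark package):

  import org.apache.spark.SparkContext

  assert(SparkContext.numDriverCores("local") == 1)
  assert(SparkContext.numDriverCores("local[4]") == 4)
  assert(SparkContext.numDriverCores("local[4, 2]") == 4)       // local[N, maxRetries]
  assert(SparkContext.numDriverCores("local[*]") == Runtime.getRuntime.availableProcessors())
  assert(SparkContext.numDriverCores("spark://host:7077") == 0) // driver does not run tasks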
@@ -2543,18 +2582,7 @@ object SparkContext extends Logging { private def createTaskScheduler( sc: SparkContext, master: String): (SchedulerBackend, TaskScheduler) = { - // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r + import SparkMasterRegex._ // When running locally, don't try to re-execute tasks on failure. val MAX_LOCAL_TASK_FAILURES = 1 @@ -2695,6 +2723,24 @@ object SparkContext extends Logging { } } +/** + * A collection of regexes for extracting information from the master string. + */ +private object SparkMasterRegex { + // Regular expression used for local[N] and local[*] master formats + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r +} + /** * A class encapsulating how to convert some type T to Writable. It stores both the Writable class * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. 
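Pulling the master-string patterns into SparkMasterRegex lets createTaskScheduler and numDriverCores share one set of regexes instead of redefining them. A small sketch of how the extractors are used (simplified from the match in createTaskScheduler; SparkMasterRegex is package-private, so this assumes code inside org.apache.spark):

  import org.apache.spark.SparkMasterRegex._

  // Illustrative classification of a master URL using the shared regexes.
  def describeMaster(master: String): String = master match {
    case "local"                            => "local, single thread"
    case LOCAL_N_REGEX(threads)             => s"local, $threads threads"
    case LOCAL_N_FAILURES_REGEX(threads, f) => s"local, $threads threads, $f max task failures"
    case LOCAL_CLUSTER_REGEX(n, cores, mem) => s"local-cluster: $n workers, $cores cores, $mem MB each"
    case SPARK_REGEX(url)                   => s"standalone master at $url"
    case MESOS_REGEX(_)                     => "mesos / zk cluster"
    case SIMR_REGEX(url)                    => s"simr cluster at $url"
    case _                                  => "other (e.g. yarn)"
  }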
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0f1e2e069568d..b5c35c569e45f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,20 +20,19 @@ package org.apache.spark import java.io.File import java.net.Socket -import akka.actor.ActorSystem - import scala.collection.mutable import scala.util.Properties +import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} @@ -42,7 +41,7 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -58,6 +57,7 @@ import org.apache.spark.util.{RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, + _actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -70,6 +70,8 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, + // TODO: unify these *MemoryManager classes (SPARK-10984) + val memoryManager: MemoryManager, val shuffleMemoryManager: ShuffleMemoryManager, val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, @@ -77,7 +79,7 @@ class SparkEnv ( // TODO Remove actorSystem @deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0") - val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = _actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -101,6 +103,9 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() + if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { + actorSystem.shutdown() + } rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -185,6 +190,7 @@ object SparkEnv extends Logging { conf: SparkConf, isLocal: Boolean, listenerBus: LiveListenerBus, + numCores: Int, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") @@ -197,6 +203,7 @@ object SparkEnv extends Logging { port, isDriver = true, isLocal = isLocal, + numUsableCores = numCores, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -236,8 
+243,8 @@ object SparkEnv extends Logging { port: Int, isDriver: Boolean, isLocal: Boolean, + numUsableCores: Int, listenerBus: LiveListenerBus = null, - numUsableCores: Int = 0, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // Listener bus is only used on the driver @@ -250,7 +257,13 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = + if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { + rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + } else { + // Create a ActorSystem for legacy codes + AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { @@ -319,23 +332,23 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) - val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores) - - val blockTransferService = - conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { - case "netty" => - new NettyBlockTransferService(conf, securityManager, numUsableCores) - case "nio" => - logWarning("NIO-based block transfer service is deprecated, " + - "and will be removed in Spark 1.6.0.") - new NioBlockTransferService(conf, securityManager) + val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) + val memoryManager: MemoryManager = + if (useLegacyMemoryManager) { + new StaticMemoryManager(conf) + } else { + new UnifiedMemoryManager(conf) } + val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores) + + val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), @@ -343,8 +356,8 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. 
val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, - numUsableCores) + serializer, conf, memoryManager, mapOutputTracker, shuffleManager, + blockTransferService, securityManager, numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) @@ -404,6 +417,7 @@ object SparkEnv extends Logging { val envInstance = new SparkEnv( executorId, rpcEnv, + actorSystem, serializer, closureSerializer, cacheManager, @@ -416,6 +430,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, + memoryManager, shuffleMemoryManager, executorMemoryManager, outputCommitCoordinator, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index f5dd36cbcfe6d..ac6eaab20d8d2 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ private[spark] -class SparkHadoopWriter(@transient jobConf: JobConf) +class SparkHadoopWriter(jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { @@ -104,8 +104,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 934d00dc708b9..9335c5f4160bf 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry @@ -48,6 +52,8 @@ case object Success extends TaskEndReason sealed trait TaskFailedReason extends TaskEndReason { /** Error message displayed in the web UI. */ def toErrorString: String + + def shouldEventuallyFailJob: Boolean = true } /** @@ -191,9 +197,18 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. 
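shouldEventuallyFailJob gives the scheduler a single switch for deciding whether a task failure should count toward the failure limit; TaskCommitDenied and normal executor exits now opt out. A sketch of the kind of check a consumer of task end reasons can write against this trait (the helper name is illustrative):

  import org.apache.spark.{TaskEndReason, TaskFailedReason}

  // Illustrative: count only failures that are meant to eventually fail the stage.
  def countsTowardsTaskFailures(reason: TaskEndReason): Boolean = reason match {
    case failed: TaskFailedReason => failed.shouldEventuallyFailJob
    case _                        => false // Success and other non-failure end reasons
  }

Under this scheme, a flood of denied speculative commits or executors that exited normally no longer pushes a stage toward being failed.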
*/ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" + /** + * If a task failed because its attempt to commit was denied, do not count this failure + * towards failing the stage. This is intended to prevent spurious stage failures in cases + * where many speculative tasks are launched and denied to commit. + */ + override def shouldEventuallyFailJob: Boolean = false } /** @@ -202,8 +217,14 @@ case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extend * the task crashed the JVM. */ @DeveloperApi -case class ExecutorLostFailure(execId: String) extends TaskFailedReason { - override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" +case class ExecutorLostFailure(execId: String, isNormalExit: Boolean = false) + extends TaskFailedReason { + override def toErrorString: String = { + val exitBehavior = if (isNormalExit) "normally" else "abnormally" + s"ExecutorLostFailure (executor ${execId} exited ${exitBehavior})" + } + + override def shouldEventuallyFailJob: Boolean = !isNormalExit } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 888763a3e8ebf..acfe751f6c746 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -24,10 +24,14 @@ import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils /** @@ -154,4 +158,51 @@ private[spark] object TestUtils { " @Override public String toString() { return \"" + toStringValue + "\"; }}") createCompiledClass(className, destDir, sourceFile, classpathUrls) } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs spilled. + */ + def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not") + } + + /** + * Run some code involving jobs submitted to the given context and assert that the jobs + * did not spill. + */ + def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = { + val spillListener = new SpillListener + sc.addSparkListener(spillListener) + body + assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") + } + +} + + +/** + * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. 
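assertSpilled and assertNotSpilled wrap a block of job-submitting code and then consult the SpillListener defined next; a typical test fragment might look like the sketch below (TestUtils is private[spark], and whether the first block really spills depends on the memory configuration of the test, so treat this as illustrative):

  import org.apache.spark.{SparkContext, TestUtils}

  def checkSpillBehaviour(sc: SparkContext): Unit = {
    // Expected to spill under a suitably small memory configuration.
    TestUtils.assertSpilled(sc, "groupByKey") {
      sc.parallelize(1 to 100000, 4).map(i => (i % 10, i)).groupByKey().count()
    }
    // A tiny map-only job should never spill.
    TestUtils.assertNotSpilled(sc, "small map") {
      sc.parallelize(1 to 10).map(_ * 2).count()
    }
  }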
+ */ +private class SpillListener extends SparkListener { + private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] + private val spilledStageIds = new mutable.HashSet[Int] + + def numSpilledStages: Int = spilledStageIds.size + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + stageIdToTaskMetrics.getOrElseUpdate( + taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics + } + + override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = { + val stageId = stageComplete.stageInfo.stageId + val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten + val spilled = metrics.map(_.memoryBytesSpilled).sum > 0 + if (spilled) { + spilledStageIds += stageId + } + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index fb787979c1820..8344f6368ac48 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -239,7 +239,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapSideCombine: Boolean, serializer: Serializer): JavaPairRDD[K, C] = { implicit val ctag: ClassTag[C] = fakeClassTag - fromRDD(rdd.combineByKey( + fromRDD(rdd.combineByKeyWithClassTag( createCombiner, mergeValue, mergeCombiners, diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b4d152b336602..8464b578ed09e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,10 +39,9 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( - @transient parent: RDD[_], + parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. 
+ */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. */ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } @@ -601,7 +633,7 @@ private[spark] object PythonRDD extends Logging { * * The thread will terminate after all the data are sent or any exceptions happen. 
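Factoring the worker protocol out of PythonRDD into PythonRunner means any iterator can be pushed through a Python worker, not just a parent RDD partition. A rough sketch of calling it directly, using the constructor and compute signatures shown in this diff (PythonRunner is private[spark], so this assumes code inside the org.apache.spark package; all argument values would come from the caller):

  import java.util.{List => JList, Map => JMap}
  import org.apache.spark.{Accumulator, TaskContext}
  import org.apache.spark.api.python.{PythonBroadcast, PythonRunner}
  import org.apache.spark.broadcast.Broadcast

  def runThroughPythonWorker(
      command: Array[Byte],
      envVars: JMap[String, String],
      pythonIncludes: JList[String],
      pythonExec: String,
      pythonVer: String,
      broadcastVars: JList[Broadcast[PythonBroadcast]],
      accumulator: Accumulator[JList[Array[Byte]]],
      input: Iterator[Array[Byte]],
      partitionIndex: Int,
      context: TaskContext): Iterator[Array[Byte]] = {
    val runner = new PythonRunner(
      command, envVars, pythonIncludes, pythonExec, pythonVer,
      broadcastVars, accumulator, bufferSize = 65536, reuse_worker = true)
    runner.compute(input, partitionIndex, context)
  }

PythonRDD.compute in this patch follows exactly this pattern, with the input iterator taken from its parent RDD partition.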
*/ - private def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) @@ -785,7 +817,7 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") @@ -839,7 +871,8 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * write the data into disk after deserialization, then Python can read it from disks. */ // scalastyle:off no.finalize -private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { +private[spark] class PythonBroadcast(@transient var path: String) extends Serializable + with Logging { /** * Read data from disks, then copy it to `out` @@ -875,7 +908,9 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial if (!path.isEmpty) { val file = new File(path) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file.getPath}") + } } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 31e534f160eeb..292ac4cfc35b9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.8.2.1-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index bb82f3285f1d9..2a792d81994fd 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -125,10 +125,11 @@ private[r] class RBackendHandler(server: RBackend) val methods = cls.getMethods val selectedMethods = methods.filter(m => m.getName == methodName) if (selectedMethods.length > 0) { - val methods = selectedMethods.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - } - if (methods.isEmpty) { + val index = findMatchedSignature( + selectedMethods.map(_.getParameterTypes), + args) + + if (index.isEmpty) { logWarning(s"cannot find matching method ${cls}.$methodName. 
" + s"Candidates are:") selectedMethods.foreach { method => @@ -136,18 +137,29 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args : _*) + + val ret = selectedMethods(index.get).invoke(obj, args : _*) // Write status bit writeInt(dos, 0) writeObject(dos, ret.asInstanceOf[AnyRef]) } else if (methodName == "") { // methodName should be "" for constructor - val ctor = cls.getConstructors.filter { x => - matchMethod(numArgs, args, x.getParameterTypes) - }.head + val ctors = cls.getConstructors + val index = findMatchedSignature( + ctors.map(_.getParameterTypes), + args) - val obj = ctor.newInstance(args : _*) + if (index.isEmpty) { + logWarning(s"cannot find matching constructor for ${cls}. " + + s"Candidates are:") + ctors.foreach { ctor => + logWarning(s"$cls(${ctor.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched constructor found for $cls") + } + + val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) @@ -166,40 +178,79 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { - (0 until numArgs).map { arg => + (0 until numArgs).map { _ => readObject(dis) }.toArray } - // Checks if the arguments passed in args matches the parameter types. - // NOTE: Currently we do exact match. We may add type conversions later. - def matchMethod( - numArgs: Int, - args: Array[java.lang.Object], - parameterTypes: Array[Class[_]]): Boolean = { - if (parameterTypes.length != numArgs) { - return false - } + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- 0 until parameterTypesOfMethods.length) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. 
+ } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if (!parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } - for (i <- 0 to numArgs - 1) { - val parameterType = parameterTypes(i) - var parameterWrapperType = parameterType - - // Convert native parameters to Object types as args is Array[Object] here - if (parameterType.isPrimitive) { - parameterWrapperType = parameterType match { - case java.lang.Integer.TYPE => classOf[java.lang.Integer] - case java.lang.Long.TYPE => classOf[java.lang.Integer] - case java.lang.Double.TYPE => classOf[java.lang.Double] - case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] - case _ => parameterType + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. + + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + (0 until numArgs).map { i => + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) } - } - if (!parameterWrapperType.isInstance(args(i))) { - return false } } - true + None } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18c..fd5646b5b6372 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 26ad4f1d4697e..da126bac7ad1f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -21,11 +21,20 @@ import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConverters._ +import scala.collection.mutable.WrappedArray /** * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { + type ReadObject = (DataInputStream, Char) => Object + type WriteObject = (DataOutputStream, Object) => Boolean + + var sqlSerDe: (ReadObject, WriteObject) = _ + + def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { + this.sqlSerDe = sqlSerDe + } // Type mapping from R to Java // @@ -62,11 +71,22 @@ private[spark] object SerDe { case 'c' => readString(dis) case 'e' => readMap(dis) case 'r' => readBytes(dis) + case 'a' => readArray(dis) case 'l' => readList(dis) case 'D' => readDate(dis) case 't' => readTime(dis) case 'j' => JVMObjectTracker.getObject(readString(dis)) - case _ => throw new 
IllegalArgumentException(s"Invalid type $dataType") + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + val obj = (sqlSerDe._1)(dis, dataType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid type $dataType") + } else { + obj + } + } } } @@ -140,7 +160,8 @@ private[spark] object SerDe { (0 until len).map(_ => readString(in)).toArray } - def readList(dis: DataInputStream): Array[_] = { + // All elements of an array must be of the same type + def readArray(dis: DataInputStream): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) @@ -149,26 +170,43 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) - case 'l' => { + case 'a' => + val len = readInt(dis) + (0 until len).map(_ => readArray(dis)).toArray + case 'l' => val len = readInt(dis) (0 until len).map(_ => readList(dis)).toArray - } - case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + case _ => + if (sqlSerDe == null || sqlSerDe._1 == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + val len = readInt(dis) + (0 until len).map { _ => + val obj = (sqlSerDe._1)(dis, arrType) + if (obj == null) { + throw new IllegalArgumentException (s"Invalid array type $arrType") + } else { + obj + } + }.toArray + } } } + // Each element of a list can be of different type. They are all represented + // as Object on JVM side + def readList(dis: DataInputStream): Array[Object] = { + val len = readInt(dis) + (0 until len).map(_ => readObject(dis)).toArray + } + def readMap(in: DataInputStream): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { - val keysType = readObjectType(in) - val keysLen = readInt(in) - val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - - val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => { - val valueType = readObjectType(in) - readTypedObject(in, valueType) - }) + // Keys is an array of String + val keys = readArray(in).asInstanceOf[Array[Object]] + val values = readList(in) + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() @@ -208,98 +246,139 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") } } - def writeObject(dos: DataOutputStream, value: Object): Unit = { - if (value == null) { + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + + def writeObject(dos: DataOutputStream, obj: Object): Unit = { + if (obj == null) { writeType(dos, "void") } else { - value.getClass.getName match { - case "java.lang.Character" => + // Convert ArrayType collected from DataFrame to Java array + // Collected data of ArrayType from a DataFrame is observed to be of + // type "scala.collection.mutable.WrappedArray" + val value = + if (obj.isInstanceOf[WrappedArray[_]]) { + 
obj.asInstanceOf[WrappedArray[_]].toArray + } else { + obj + } + + value match { + case v: java.lang.Character => writeType(dos, "character") - writeString(dos, value.asInstanceOf[Character].toString) - case "java.lang.String" => + writeString(dos, v.toString) + case v: java.lang.String => writeType(dos, "character") - writeString(dos, value.asInstanceOf[String]) - case "java.lang.Long" => + writeString(dos, v) + case v: java.lang.Long => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "java.lang.Float" => + writeDouble(dos, v.toDouble) + case v: java.lang.Float => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "java.math.BigDecimal" => + writeDouble(dos, v.toDouble) + case v: java.math.BigDecimal => writeType(dos, "double") - val javaDecimal = value.asInstanceOf[java.math.BigDecimal] - writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "java.lang.Double" => + writeDouble(dos, scala.math.BigDecimal(v).toDouble) + case v: java.lang.Double => writeType(dos, "double") - writeDouble(dos, value.asInstanceOf[Double]) - case "java.lang.Byte" => + writeDouble(dos, v) + case v: java.lang.Byte => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Byte].toInt) - case "java.lang.Short" => + writeInt(dos, v.toInt) + case v: java.lang.Short => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Short].toInt) - case "java.lang.Integer" => + writeInt(dos, v.toInt) + case v: java.lang.Integer => writeType(dos, "integer") - writeInt(dos, value.asInstanceOf[Int]) - case "java.lang.Boolean" => + writeInt(dos, v) + case v: java.lang.Boolean => writeType(dos, "logical") - writeBoolean(dos, value.asInstanceOf[Boolean]) - case "java.sql.Date" => + writeBoolean(dos, v) + case v: java.sql.Date => writeType(dos, "date") - writeDate(dos, value.asInstanceOf[Date]) - case "java.sql.Time" => + writeDate(dos, v) + case v: java.sql.Time => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Time]) - case "java.sql.Timestamp" => + writeTime(dos, v) + case v: java.sql.Timestamp => writeType(dos, "time") - writeTime(dos, value.asInstanceOf[Timestamp]) + writeTime(dos, v) // Handle arrays // Array of primitive types // Special handling for byte array - case "[B" => + case v: Array[Byte] => writeType(dos, "raw") - writeBytes(dos, value.asInstanceOf[Array[Byte]]) + writeBytes(dos, v) - case "[C" => + case v: Array[Char] => writeType(dos, "array") - writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) - case "[S" => + writeStringArr(dos, v.map(_.toString)) + case v: Array[Short] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) - case "[I" => + writeIntArr(dos, v.map(_.toInt)) + case v: Array[Int] => writeType(dos, "array") - writeIntArr(dos, value.asInstanceOf[Array[Int]]) - case "[J" => + writeIntArr(dos, v) + case v: Array[Long] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) - case "[F" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Float] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Float]].map(_.toDouble)) - case "[D" => + writeDoubleArr(dos, v.map(_.toDouble)) + case v: Array[Double] => writeType(dos, "array") - writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) - case "[Z" => + writeDoubleArr(dos, v) + case v: Array[Boolean] => writeType(dos, "array") - writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + writeBooleanArr(dos, v) // Array of 
objects, null objects use "void" type - case c if c.startsWith("[") => + case v: Array[Object] => writeType(dos, "list") - val array = value.asInstanceOf[Array[Object]] - writeInt(dos, array.length) - array.foreach(elem => writeObject(dos, elem)) + writeInt(dos, v.length) + v.foreach(elem => writeObject(dos, elem)) + + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } case _ => - writeType(dos, "jobj") - writeJObj(dos, value) + if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + writeType(dos, "jobj") + writeJObj(dos, value) + } } } } @@ -329,12 +408,11 @@ private[spark] object SerDe { out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) } - // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { - val len = value.length - out.writeInt(len + 1) // For the \0 - out.writeBytes(value) - out.writeByte(0) + val utf8 = value.getBytes("UTF-8") + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) } def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index d8084a57658ad..3feb7cea593e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -69,9 +69,14 @@ private[deploy] object DeployMessages { // Master to Worker + sealed trait RegisterWorkerResponse + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage + with RegisterWorkerResponse + + case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse - case class RegisterWorkerFailed(message: String) extends DeployMessage + case object MasterInStandby extends DeployMessage with RegisterWorkerResponse case class ReconnectWorker(masterUrl: String) extends DeployMessage diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 4b28866dcaa7c..7d160b6790eaa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -175,8 +175,10 @@ private[deploy] object RPackageUtils extends Logging { print(s"ERROR: Failed to build R package in $file.", printStream) print(RJarDoc, printStream) } - } finally { - rSource.delete() // clean up + } finally { // clean up + if (!rSource.delete()) { + logWarning(s"Error deleting ${rSource.getPath()}") + } } } else { if (verbose) { @@ -211,7 +213,9 @@ private[deploy] object RPackageUtils extends Logging { val filesToBundle = listFilesRecursively(dir, Seq(".zip")) // create a zip file from scratch, do not append to existing file. 
val zipFile = new File(dir, name) - zipFile.delete() + if (!zipFile.delete()) { + logWarning(s"Error deleting ${zipFile.getPath()}") + } val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) try { filesToBundle.foreach { file => diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 05b954ce36998..ed183cf16a9cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.{SparkException, SparkUserAppException} import org.apache.spark.util.RedirectThread /** @@ -39,7 +40,16 @@ object RRunner { // Time to wait for SparkR backend to initialize in seconds val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt - val rCommand = "Rscript" + val rCommand = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + var cmd = sys.props.getOrElse("spark.sparkr.r.command", "Rscript") + cmd = sys.props.getOrElse("spark.r.command", cmd) + if (sys.props.getOrElse("spark.submit.deployMode", "client") == "client") { + cmd = sys.props.getOrElse("spark.r.driver.command", cmd) + } + cmd + } // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode @@ -84,12 +94,15 @@ object RRunner { } finally { sparkRBackend.close() } - System.exit(returnCode) + if (returnCode != 0) { + throw new SparkUserAppException(returnCode) + } } else { + val errorMessage = s"SparkR backend did not initialize in $backendTimeout seconds" // scalastyle:off println - System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds") + System.err.println(errorMessage) // scalastyle:on println - System.exit(-1) + throw new SparkException(errorMessage) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f7723ef5bde4c..d606b80c03c98 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -192,7 +192,9 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. 
*/ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } @@ -204,7 +206,9 @@ class SparkHadoopUtil extends Logging { */ def getTaskAttemptIDFromTaskAttemptContext( context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] } @@ -381,20 +385,13 @@ class SparkHadoopUtil extends Logging { object SparkHadoopUtil { - private val hadoop = { - val yarnMode = java.lang.Boolean.valueOf( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if (yarnMode) { - try { - Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") - .newInstance() - .asInstanceOf[SparkHadoopUtil] - } catch { - case e: Exception => throw new SparkException("Unable to load YARN support", e) - } - } else { - new SparkHadoopUtil - } + private lazy val hadoop = new SparkHadoopUtil + private lazy val yarn = try { + Utils.classForName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .newInstance() + .asInstanceOf[SparkHadoopUtil] + } catch { + case e: Exception => throw new SparkException("Unable to load YARN support", e) } val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" @@ -402,6 +399,13 @@ object SparkHadoopUtil { val SPARK_YARN_CREDS_COUNTER_DELIM = "-" def get: SparkHadoopUtil = { - hadoop + // Check each time to support changing to/from YARN + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (yarnMode) { + yarn + } else { + hadoop + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 86fcf942c2c4e..640cc325281a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -319,8 +319,8 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isPython => - printErrorAndExit("Cluster deploy mode is currently not supported for python " + + case (MESOS, CLUSTER) if args.isR => + printErrorAndExit("Cluster deploy mode is currently not supported for R " + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + @@ -551,7 +551,15 @@ object SparkSubmit { if (isMesosCluster) { assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" - childArgs += (args.primaryResource, args.mainClass) + if (args.isPython) { + // Second argument is main class + childArgs += (args.primaryResource, "") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } + } else { + childArgs += (args.primaryResource, args.mainClass) + } if (args.childArgs != null) { childArgs ++= args.childArgs } @@ -647,6 +655,15 @@ object SparkSubmit { // scalastyle:on println } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) + case e: NoClassDefFoundError => + e.printStackTrace(printStream) + if (e.getMessage.contains("org/apache/hadoop/hive")) { + // scalastyle:off println + 
printStream.println(s"Failed to load hive class.") + printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.") + // scalastyle:on println + } + System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } // SPARK-4170 diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e573ff16c50a3..80bfda9dddb39 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -26,7 +27,7 @@ import scala.collection.mutable import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.fs.permission.AccessControlException +import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil @@ -73,7 +74,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. 
@@ -145,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up @@ -179,15 +179,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -224,12 +223,31 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." + UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + if (!fs.delete(path)) { + logWarning(s"Error deleting ${path}") + } + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], @@ -389,7 +407,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val path = new Path(logDir, attempt.logPath) if (fs.exists(path)) { - fs.delete(path, true) + if (!fs.delete(path, true)) { + logWarning(s"Error deleting ${path}") + } } } catch { case e: AccessControlException => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba1245..b347cb3be69f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. 
SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 18265df9faa2c..d03bab3820bb2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -30,28 +30,35 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin parse(args.toList) private def parse(args: List[String]): Unit = { - args match { - case ("--dir" | "-d") :: value :: tail => - logWarning("Setting log directory through the command line is deprecated as of " + - "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") - conf.set("spark.history.fs.logDirectory", value) - System.setProperty("spark.history.fs.logDirectory", value) - parse(tail) + if (args.length == 1) { + setLogDirectory(args.head) + } else { + args match { + case ("--dir" | "-d") :: value :: tail => + setLogDirectory(value) + parse(tail) - case ("--help" | "-h") :: tail => - printUsageAndExit(0) + case ("--help" | "-h") :: tail => + printUsageAndExit(0) - case ("--properties-file") :: value :: tail => - propertiesFile = value - parse(tail) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) - case Nil => + case Nil => - case _ => - printUsageAndExit(1) + case _ => + printUsageAndExit(1) + } } } + private def setLogDirectory(value: String): Unit = { + logWarning("Setting log directory through the command line is deprecated as of " + + "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") + conf.set("spark.history.fs.logDirectory", value) + } + // This mutates the SparkConf, so all accesses to it must be made after this line Utils.loadDefaultSparkProperties(conf, propertiesFile) @@ -62,6 +69,8 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin |Usage: HistoryServer [options] | |Options: + | DIR Deprecated; set spark.history.fs.logDirectory directly + | --dir DIR (-d DIR) Deprecated; set spark.history.fs.logDirectory directly | --properties-file FILE Path to a custom Spark properties file. | Default is conf/spark-defaults.conf. 
| @@ -90,3 +99,4 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin } } + diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index aa379d4cd61e7..1aa8cd5013b49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -45,7 +45,10 @@ private[master] class FileSystemPersistenceEngine( } override def unpersist(name: String): Unit = { - new File(dir + File.separator + name).delete() + val f = new File(dir + File.separator + name) + if (!f.delete()) { + logWarning(s"Error deleting ${f.getPath()}") + } } override def read[T: ClassTag](prefix: String): Seq[T] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9bec..6715d6c70f497 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -233,31 +233,6 @@ private[deploy] class Master( System.exit(0) } - case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - workerHost, workerPort, cores, Utils.megabytesToString(memory))) - if (state == RecoveryState.STANDBY) { - // ignore, don't send response - } else if (idToWorker.contains(id)) { - workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) - } else { - val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerUiPort, publicAddress) - if (registerWorker(worker)) { - persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl)) - schedule() - } else { - val workerAddress = worker.endpoint.address - logWarning("Worker registration failed. Attempted to re-register worker at same " + - "address: " + workerAddress) - workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress)) - } - } - } - case RegisterApplication(description, driver) => { // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { @@ -387,6 +362,31 @@ private[deploy] class Master( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + workerHost, workerPort, cores, Utils.megabytesToString(memory))) + if (state == RecoveryState.STANDBY) { + context.reply(MasterInStandby) + } else if (idToWorker.contains(id)) { + context.reply(RegisterWorkerFailed("Duplicate worker ID")) + } else { + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + workerRef, workerUiPort, publicAddress) + if (registerWorker(worker)) { + persistenceEngine.addWorker(worker) + context.reply(RegisteredWorker(self, masterWebUiUrl)) + schedule() + } else { + val workerAddress = worker.endpoint.address + logWarning("Worker registration failed. 
Attempted to re-register worker at same " + + "address: " + workerAddress) + context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) + } + } + } + case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 1fe956320a1b8..957a928bc402b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -392,15 +392,14 @@ private[spark] object RestSubmissionClient { mainClass: String, appArgs: Array[String], conf: SparkConf, - env: Map[String, String] = sys.env): SubmitRestProtocolResponse = { + env: Map[String, String] = Map()): SubmitRestProtocolResponse = { val master = conf.getOption("spark.master").getOrElse { throw new IllegalArgumentException("'spark.master' must be set.") } val sparkProperties = conf.getAll.toMap - val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( - appResource, mainClass, appArgs, sparkProperties, environmentVariables) + appResource, mainClass, appArgs, sparkProperties, env) client.createSubmission(submitRequest) } @@ -413,6 +412,16 @@ private[spark] object RestSubmissionClient { val mainClass = args(1) val appArgs = args.slice(2, args.size) val conf = new SparkConf - run(appResource, mainClass, appArgs, conf) + val env = filterSystemEnvironment(sys.env) + run(appResource, mainClass, appArgs, conf, env) + } + + /** + * Filter non-spark environment variables from any environment. 
+ */ + private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { + env.filter { case (k, _) => + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 770927c80f7a4..a45867e7680ec 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -26,7 +26,7 @@ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFut import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext -import scala.util.Random +import scala.util.{Failure, Random, Success} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -213,8 +213,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -271,8 +270,7 @@ private[deploy] class Worker( logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) - masterEndpoint.send(RegisterWorker( - workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -329,7 +327,7 @@ private[deploy] class Worker( registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(ReregisterWithMaster) + Option(self).foreach(_.send(ReregisterWithMaster)) } }, INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, @@ -341,25 +339,54 @@ private[deploy] class Worker( } } - override def receive: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterRef, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterRef.address.toSparkURL) - registered = true - changeMaster(masterRef, masterWebUiUrl) - forwordMessageScheduler.scheduleAtFixedRate(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - self.send(SendHeartbeat) - } - }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) - if (CLEANUP_ENABLED) { - logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = { + masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + .onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" + case Success(msg) => + Utils.tryLogNonFatalError { + handleRegisterResponse(msg) + } + case Failure(e) => + logError(s"Cannot register with master: ${masterEndpoint.address}", e) + System.exit(1) + }(ThreadUtils.sameThread) + } + + private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { + msg match { + case RegisteredWorker(masterRef, 
masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) + registered = true + changeMaster(masterRef, masterWebUiUrl) forwordMessageScheduler.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - self.send(WorkDirCleanup) + self.send(SendHeartbeat) } - }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) - } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) + if (CLEANUP_ENABLED) { + logInfo( + s"Worker cleanup enabled; old application directories will be deleted in: $workDir") + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) + } + case RegisterWorkerFailed(message) => + if (!registered) { + logError("Worker registration failed: " + message) + System.exit(1) + } + + case MasterInStandby => + // Ignore. Master not yet ready. + } + } + + override def receive: PartialFunction[Any, Unit] = synchronized { case SendHeartbeat => if (connected) { sendToMaster(Heartbeat(workerId, self)) } @@ -399,12 +426,6 @@ private[deploy] class Worker( map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) - case RegisterWorkerFailed(message) => - if (!registered) { - logError("Worker registration failed: " + message) - System.exit(1) - } - case ReconnectWorker(masterUrl) => logInfo(s"Master with url $masterUrl requested this worker to reconnect.") registerWithMaster() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 735c4f0927150..ab56fde938bae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -24,14 +24,13 @@ import org.apache.spark.rpc._ * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) +private[spark] class WorkerWatcher( + override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false) extends RpcEndpoint with Logging { - override def onStart() { - logInfo(s"Connecting to worker $workerUrl") - if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByURI(workerUrl) - } + logInfo(s"Connecting to worker $workerUrl") + if (!isTesting) { + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) } // Used to avoid shutting down JVM during tests @@ -40,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin // true rather than calling `System.exit`. The user can check `isShutDown` to know if // `exitNonZero` is called. 
private[deploy] var isShutDown = false - private[deploy] def setTesting(testing: Boolean) = isTesting = testing - private var isTesting = false // Lets filter events only from the worker's rpc system private val expectedAddress = RpcAddress.fromURIString(workerUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598c..1a0598e50dcf1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index fcd76ec52742a..a9c6a05ecd434 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -59,12 +59,12 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisteredExecutor.type]( + ref.ask[RegisterExecutorResponse]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { - Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor + Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse } case Failure(e) => { logError(s"Cannot register with driver: $driverUrl", e) @@ -110,6 +110,11 @@ private[spark] class CoarseGrainedExecutorBackend( case StopExecutor => logInfo("Driver commanded a shutdown") + // Cannot shutdown here because an ack may need to be sent back to the caller. So send + // a message to self to actually do the shutdown. 
+ self.send(Shutdown) + + case Shutdown => executor.stop() stop() rpcEnv.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da1..7d84889a2def0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index a5ad47293f1c2..e2ffc3b64e5db 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -131,8 +131,8 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat */ @Experimental class PortableDataStream( - @transient isplit: CombineFileSplit, - @transient context: TaskAttemptContext, + isplit: CombineFileSplit, + context: TaskAttemptContext, index: Integer) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efca..9dc36704a676d 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -148,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala new file mode 100644 index 0000000000000..3ea984c501e02 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.net.{InetAddress, Socket} + +import org.apache.spark.SPARK_VERSION +import org.apache.spark.launcher.LauncherProtocol._ +import org.apache.spark.util.ThreadUtils + +/** + * A class that can be used to talk to a launcher server. 
Users should extend this class to + * provide implementation for the abstract methods. + * + * See `LauncherServer` for an explanation of how launcher communication works. + */ +private[spark] abstract class LauncherBackend { + + private var clientThread: Thread = _ + private var connection: BackendConnection = _ + private var lastState: SparkAppHandle.State = _ + @volatile private var _isConnected = false + + def connect(): Unit = { + val port = sys.env.get(LauncherProtocol.ENV_LAUNCHER_PORT).map(_.toInt) + val secret = sys.env.get(LauncherProtocol.ENV_LAUNCHER_SECRET) + if (port != None && secret != None) { + val s = new Socket(InetAddress.getLoopbackAddress(), port.get) + connection = new BackendConnection(s) + connection.send(new Hello(secret.get, SPARK_VERSION)) + clientThread = LauncherBackend.threadFactory.newThread(connection) + clientThread.start() + _isConnected = true + } + } + + def close(): Unit = { + if (connection != null) { + try { + connection.close() + } finally { + if (clientThread != null) { + clientThread.join() + } + } + } + } + + def setAppId(appId: String): Unit = { + if (connection != null) { + connection.send(new SetAppId(appId)) + } + } + + def setState(state: SparkAppHandle.State): Unit = { + if (connection != null && lastState != state) { + connection.send(new SetState(state)) + lastState = state + } + } + + /** Return whether the launcher handle is still connected to this backend. */ + def isConnected(): Boolean = _isConnected + + /** + * Implementations should provide this method, which should try to stop the application + * as gracefully as possible. + */ + protected def onStopRequest(): Unit + + /** + * Callback for when the launcher handle disconnects from this backend. + */ + protected def onDisconnected() : Unit = { } + + + private class BackendConnection(s: Socket) extends LauncherConnection(s) { + + override protected def handle(m: Message): Unit = m match { + case _: Stop => + onStopRequest() + + case _ => + throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") + } + + override def close(): Unit = { + try { + super.close() + } finally { + onDisconnected() + _isConnected = false + } + } + + } + +} + +private object LauncherBackend { + + val threadFactory = ThreadUtils.namedThreadFactory("LauncherBackend") + +} diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 0c096656f9236..a2add61617281 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -40,7 +40,7 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm cmd.add(s"-Xms${memoryMb}M") cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - addPermGenSizeOpt(cmd) + CommandBuilderUtils.addPermGenSizeOpt(cmd) addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f405b732e4725..f7298e8d5c62c 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit 
= { + splitId: Int): Unit = { val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) @@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala new file mode 100644 index 0000000000000..7168ac549106f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} + + +/** + * An abstract memory manager that enforces how memory is shared between execution and storage. + * + * In this context, execution memory refers to that used for computation in shuffles, joins, + * sorts and aggregations, while storage memory refers to that used for caching and propagating + * internal data across the cluster. There exists one of these per JVM. 
+ */ +private[spark] abstract class MemoryManager extends Logging { + + // The memory store used to evict cached blocks + private var _memoryStore: MemoryStore = _ + protected def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalArgumentException("memory store not initialized yet") + } + _memoryStore + } + + // Amount of execution/storage memory in use, accesses must be synchronized on `this` + protected var _executionMemoryUsed: Long = 0 + protected var _storageMemoryUsed: Long = 0 + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Total available memory for execution, in bytes. + */ + def maxExecutionMemory: Long + + /** + * Total available memory for storage, in bytes. + */ + def maxStorageMemory: Long + + // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return number of bytes successfully granted (<= N). + */ + def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + + /** + * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * + * This extra method allows subclasses to differentiate behavior between acquiring storage + * memory and acquiring unroll memory. For instance, the memory management model in Spark + * 1.5 and before places a limit on the amount of space that can be freed from unrolling. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * + * @return whether all N bytes were successfully granted. + */ + def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, evictedBlocks) + } + + /** + * Release N bytes of execution memory. + */ + def releaseExecutionMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _executionMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of execution " + + s"memory when we only have ${_executionMemoryUsed} bytes") + _executionMemoryUsed = 0 + } else { + _executionMemoryUsed -= numBytes + } + } + + /** + * Release N bytes of storage memory. + */ + def releaseStorageMemory(numBytes: Long): Unit = synchronized { + if (numBytes > _storageMemoryUsed) { + logWarning(s"Attempted to release $numBytes bytes of storage " + + s"memory when we only have ${_storageMemoryUsed} bytes") + _storageMemoryUsed = 0 + } else { + _storageMemoryUsed -= numBytes + } + } + + /** + * Release all storage memory acquired. + */ + def releaseAllStorageMemory(): Unit = synchronized { + _storageMemoryUsed = 0 + } + + /** + * Release N bytes of unroll memory. 
+ */ + def releaseUnrollMemory(numBytes: Long): Unit = synchronized { + releaseStorageMemory(numBytes) + } + + /** + * Execution memory currently in use, in bytes. + */ + final def executionMemoryUsed: Long = synchronized { + _executionMemoryUsed + } + + /** + * Storage memory currently in use, in bytes. + */ + final def storageMemoryUsed: Long = synchronized { + _storageMemoryUsed + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala new file mode 100644 index 0000000000000..fa44f3723415d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockId, BlockStatus} + + +/** + * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. + * + * The sizes of the execution and storage regions are determined through + * `spark.shuffle.memoryFraction` and `spark.storage.memoryFraction` respectively. The two + * regions are cleanly separated such that neither usage can borrow memory from the other. + */ +private[spark] class StaticMemoryManager( + conf: SparkConf, + override val maxExecutionMemory: Long, + override val maxStorageMemory: Long) + extends MemoryManager { + + def this(conf: SparkConf) { + this( + conf, + StaticMemoryManager.getMaxExecutionMemory(conf), + StaticMemoryManager.getMaxStorageMemory(conf)) + } + + // Max number of bytes worth of blocks to evict when unrolling + private val maxMemoryToEvictForUnroll: Long = { + (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong + } + + /** + * Acquire N bytes of memory for execution. + * @return number of bytes successfully granted (<= N). + */ + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + assert(numBytes >= 0) + assert(_executionMemoryUsed <= maxExecutionMemory) + val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) + _executionMemoryUsed += bytesToGrant + bytesToGrant + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. 
+ */ + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) + } + + /** + * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. + * + * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage + * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any, + * are added to `evictedBlocks`. + * + * @return whether all N bytes were successfully granted. + */ + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + val currentUnrollMemory = memoryStore.currentUnrollMemory + val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) + val numBytesToFree = math.min(numBytes, maxNumBytesToFree) + acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the size of space to be freed through evicting blocks + * @param evictedBlocks a holder for blocks evicted in the process + * @return whether all N bytes were successfully granted. + */ + private def acquireStorageMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) + memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) + assert(_storageMemoryUsed <= maxStorageMemory) + val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory + if (enoughMemory) { + _storageMemoryUsed += numBytesToAcquire + } + enoughMemory + } + +} + + +private[spark] object StaticMemoryManager { + + /** + * Return the total amount of memory available for the storage region, in bytes. + */ + private def getMaxStorageMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) + val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) + (systemMaxMemory * memoryFraction * safetyFraction).toLong + } + + + /** + * Return the total amount of memory available for the execution region, in bytes. + */ + private def getMaxExecutionMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) + val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) + (systemMaxMemory * memoryFraction * safetyFraction).toLong + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala new file mode 100644 index 0000000000000..5bf78d5b674b3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} + + +/** + * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that + * either side can borrow memory from the other. + * + * The region shared between execution and storage is a fraction of the total heap space + * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary + * within this space is further determined by `spark.memory.storageFraction` (default 0.5). + * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. + * + * Storage can borrow as much execution memory as is free until execution reclaims its space. + * When this happens, cached blocks will be evicted from memory until sufficient borrowed + * memory is released to satisfy the execution memory request. + * + * Similarly, execution can borrow as much storage memory as is free. However, execution + * memory is *never* evicted by storage due to the complexities involved in implementing this. + * The implication is that attempts to cache blocks may fail if execution has already eaten + * up most of the storage space, in which case the new blocks will be evicted immediately + * according to their respective storage levels. + */ +private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager { + + def this(conf: SparkConf) { + this(conf, UnifiedMemoryManager.getMaxMemory(conf)) + } + + /** + * Size of the storage region, in bytes. + * + * This region is not statically reserved; execution can borrow from it if necessary. + * Cached blocks can be evicted only if actual storage memory usage exceeds this region. + */ + private val storageRegionSize: Long = { + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong + } + + /** + * Total amount of memory, in bytes, not currently occupied by either execution or storage. + */ + private def totalFreeMemory: Long = synchronized { + assert(_executionMemoryUsed <= maxMemory) + assert(_storageMemoryUsed <= maxMemory) + assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory) + maxMemory - _executionMemoryUsed - _storageMemoryUsed + } + + /** + * Total available memory for execution, in bytes. + * In this model, this is equivalent to the amount of memory not occupied by storage. + */ + override def maxExecutionMemory: Long = synchronized { + maxMemory - _storageMemoryUsed + } + + /** + * Total available memory for storage, in bytes. + * In this model, this is equivalent to the amount of memory not occupied by execution. 
+ */ + override def maxStorageMemory: Long = synchronized { + maxMemory - _executionMemoryUsed + } + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * + * This method evicts blocks only up to the amount of memory borrowed by storage. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return number of bytes successfully granted (<= N). + */ + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + assert(numBytes >= 0) + val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize) + // If there is not enough free memory AND storage has borrowed some execution memory, + // then evict as much memory borrowed by storage as needed to grant this request + val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0 + if (shouldEvictStorage) { + val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage) + memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks) + } + val bytesToGrant = math.min(numBytes, totalFreeMemory) + _executionMemoryUsed += bytesToGrant + bytesToGrant + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(numBytes >= 0) + memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks) + val enoughMemory = totalFreeMemory >= numBytes + if (enoughMemory) { + _storageMemoryUsed += numBytes + } + enoughMemory + } + +} + +private object UnifiedMemoryManager { + + /** + * Return the total amount of memory shared between execution and storage, in bytes. 
+ */ + private def getMaxMemory(conf: SparkConf): Long = { + val systemMaxMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) + val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) + (systemMaxMemory * memoryFraction).toLong + } +} diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3b..48afe3ae3511f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a33074..4193e1d21d3c1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7c170a742fb64..76968249fb625 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
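To make the fractions used by the two memory managers above concrete, here is a rough worked example for a hypothetical 4 GB executor heap, using the default values that appear in the patch (the heap size itself is illustrative only):

  val heap = 4L << 30                            // hypothetical 4 GB executor heap
  // Unified model: spark.memory.fraction = 0.75, spark.memory.storageFraction = 0.5
  val unifiedMax    = (heap * 0.75).toLong       // 3 GB shared by execution and storage
  val storageRegion = (unifiedMax * 0.5).toLong  // 1.5 GB soft storage region (0.375 of the heap)
  // Static model: storage = 0.6 * 0.9, execution = 0.2 * 0.8 of the heap
  val staticStorage   = (heap * 0.6 * 0.9).toLong  // ~2.16 GB hard storage cap
  val staticExecution = (heap * 0.2 * 0.8).toLong  // ~0.64 GB hard execution cap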
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -55,7 +56,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ff8aae9ebe9f0..70a42f9045e6b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { @@ -137,7 +137,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") - result.success() + result.success((): Unit) } override def onFailure(e: Throwable): Unit = { logError(s"Error while uploading block $blockId", e) @@ -149,7 +149,11 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } override def close(): Unit = { - server.close() - clientFactory.close() + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala deleted file mode 100644 index 79cb0640c8672..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
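The guarded close() in NettyBlockTransferService above avoids a NullPointerException when the service is shut down before init() ever ran. A small stand-alone sketch of the same defensive pattern; the Fake* types are stand-ins invented for illustration, not the real transport classes:

    // Hypothetical stand-ins: only close() matters for this example.
    class FakeServer { def close(): Unit = println("server closed") }
    class FakeClientFactory { def close(): Unit = println("client factory closed") }

    class TransferServiceSketch {
      private var server: FakeServer = _                 // stays null until init() is called
      private var clientFactory: FakeClientFactory = _

      def init(): Unit = {
        server = new FakeServer
        clientFactory = new FakeClientFactory
      }

      // Mirrors the patched close(): tolerate close() before (or without) init().
      def close(): Unit = {
        if (server != null) {
          server.close()
        }
        if (clientFactory != null) {
          clientFactory.close()
        }
      }
    }

    object TransferServiceSketch {
      def main(args: Array[String]): Unit = {
        new TransferServiceSketch().close()   // no-op instead of a NullPointerException
      }
    }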
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -// private[spark] because we need to register them in Kryo -private[spark] case class GetBlock(id: BlockId) -private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) -private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) - -private[nio] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: BlockId = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = BlockId(idBuilder.toString) - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = typ - def getId: BlockId = id - def getData: ByteBuffer = data - def getLevel: StorageLevel = level - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) - buffer.putInt(typ).putInt(id.name.length) - id.name.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[nio] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - 
def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala deleted file mode 100644 index f1c9ea8b64ca3..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark._ -import org.apache.spark.storage.{StorageLevel, TestBlockId} - -import scala.collection.mutable.ArrayBuffer - -private[nio] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) - extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int): BlockMessage = blockMessages(i) - - def iterator: Iterator[BlockMessage] = blockMessages.iterator - - def length: Int = blockMessages.length - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + - (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - 
buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - Message.createBufferMessage(buffers) - } -} - -private[nio] object BlockMessageArray extends Logging { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear() - BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, - StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - logDebug("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - logDebug("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - logDebug("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - logDebug("Converted back to block message array") - // scalastyle:off println - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - // scalastyle:on println - } -} - - diff --git a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala deleted file mode 100644 index 9a9e22b0c2366..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.storage.BlockManager - - -private[nio] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size: Int = initialSize - - def currentSize(): Int = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - val security = if (isSecurityNeg) 1 else 0 - if (size == 0 && !gotChunkForSendingOnce) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, - hasError, security, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - val security = if (isSecurityNeg) 1 else 0 - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId(): Boolean = ackId != 0 - - def isCompletelyReceived: Boolean = !buffers(0).hasRemaining - - override def toString: String = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala deleted file mode 100644 index 8d9ebadaf79d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ /dev/null @@ -1,619 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.LinkedList - -import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.control.NonFatal - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} - -private[nio] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId, - val securityMgr: SecurityManager) - extends Logging { - - var sparkSaslServer: SparkSaslServer = null - var sparkSaslClient: SparkSaslClient = null - - def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId, - securityMgr_ : SecurityManager) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), - id_, securityMgr_) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /* channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def isSaslComplete(): Boolean - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - private def disposeSasl() { - if (sparkSaslServer != null) { - sparkSaslServer.dispose() - } - - if (sparkSaslClient != null) { - sparkSaslClient.dispose() - } - } - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. 
- // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key(): SelectionKey = channel.keyFor(selector) - - def getRemoteAddress(): InetSocketAddress = { - channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - } - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. - def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - disposeSasl() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Throwable) => Unit) { - onExceptionCallbacks.add(callback) - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks.asScala.foreach { - callback => - try { - callback(this, e) - } catch { - case NonFatal(e) => { - logWarning("Ignored error in onExceptionCallback", e) - } - } - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.length + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[nio] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId, id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslClient != null) sparkSaslClient.isComplete() else false - } - - private class Outbox { - val messages = new LinkedList[Message]() - val defaultChunkSize = 65536 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized { - messages.add(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /* val message = messages(nextMessageToBeUsed) */ - - val message = if 
(securityMgr.isAuthenticationEnabled() && !isSaslComplete()) { - // only allow sending of security messages until sasl is complete - var pos = 0 - var securityMsg: Message = null - while (pos < messages.size() && securityMsg == null) { - if (messages.get(pos).isSecurityNeg) { - securityMsg = messages.remove(pos) - } - pos = pos + 1 - } - // didn't find any security messages and auth isn't completed so return - if (securityMsg == null) return None - securityMsg - } else { - messages.removeFirst() - } - - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.add(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox() - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /* channel.socket.setSendBufferSize(256 * 1024) */ - - override def getRemoteAddress(): InetSocketAddress = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def registerAfterAuth(): Unit = { - outbox.synchronized { - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try { - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => - logError("Error connecting to " + address, e) - callOnExceptionCallbacks(e) - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. 
- val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallbacks(e) - } - } - true - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as - // normal registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /* key.interestOps(0) */ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), - e) - callOnExceptionCallbacks(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! 
isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection( - channel_ : SocketChannel, - selector_ : Selector, - id_ : ConnectionId, - securityMgr_ : SecurityManager) - extends Connection(channel_, selector_, id_, securityMgr_) { - - def isSaslComplete(): Boolean = { - if (sparkSaslServer != null) sparkSaslServer.isComplete() else false - } - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = header.securityNeg == 1 - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The receiver's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection, Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... 
- return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... - return true - } else if (bytesRead == -1) { - close() - return false - } - - /* logDebug("Read " + bytesRead + " bytes for the buffer") */ - - if (currentChunk.buffer.remaining == 0) { - /* println("Filled buffer at " + System.currentTimeMillis) */ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip() - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallbacks(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala deleted file mode 100644 index 9143918790381..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ /dev/null @@ -1,1157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.nio - -import java.io.IOException -import java.lang.ref.WeakReference -import java.net._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} -import scala.concurrent.duration._ -import scala.concurrent.{Await, ExecutionContext, Future, Promise} -import scala.language.postfixOps - -import com.google.common.base.Charsets.UTF_8 -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark._ -import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.{ThreadUtils, Utils} - -import scala.util.Try -import scala.util.control.NonFatal - -private[nio] class ConnectionManager( - port: Int, - conf: SparkConf, - securityManager: SecurityManager, - name: String = "Connection manager") - extends Logging { - - /** - * Used by sendMessageReliably to track messages being sent. - * @param message the message that was sent - * @param connectionManagerId the connection manager that sent this message - * @param completionHandler callback that's invoked when the send has completed or failed - */ - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: Try[Message] => Unit) { - - def success(ackMessage: Message) { - if (ackMessage == null) { - failure(new NullPointerException) - } - else { - completionHandler(scala.util.Success(ackMessage)) - } - } - - def failWithoutAck() { - completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) - } - - def failure(e: Throwable) { - completionHandler(scala.util.Failure(e)) - } - } - - private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = - new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) - - private val ackTimeout = - conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", - conf.get("spark.network.timeout", "120s")) - - // Get the thread counts from the Spark Configuration. - // - // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, - // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is - // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" - // parameter is necessary. 
- private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) - private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4) - private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1) - - private val handleMessageExecutor = new ThreadPoolExecutor( - handlerThreadCount, - handlerThreadCount, - conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-message-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleMessageExecutor is not handled properly", t) - } - } - } - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - ioThreadCount, - ioThreadCount, - conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-read-write-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleReadWriteExecutor is not handled properly", t) - } - } - } - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : - // which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - connectThreadCount, - connectThreadCount, - conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable](), - ThreadUtils.namedThreadFactory("handle-connect-executor")) { - - override def afterExecute(r: Runnable, t: Throwable): Unit = { - super.afterExecute(r, t) - if (t != null && NonFatal(t)) { - logError("Error in handleConnectExecutor is not handled properly", t) - } - } - } - - private val serverChannel = ServerSocketChannel.open() - // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] - with SynchronizedMap[ConnectionId, SendingConnection] - private val connectionsByKey = - new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] - with SynchronizedMap[ConnectionManagerId, SendingConnection] - // Tracks sent messages for which we are awaiting acknowledgements. 
Entries are added to this - // map when messages are sent and are removed when acknowledgement messages are received or when - // acknowledgement timeouts expire - private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor( - ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) - - @volatile - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null - - private val authEnabled = securityManager.isAuthenticationEnabled() - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - private def startService(port: Int): (ServerSocketChannel, Int) = { - serverChannel.socket.bind(new InetSocketAddress(port)) - (serverChannel, serverChannel.socket.getLocalPort) - } - Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - // used in combination with the ConnectionManagerId to create unique Connection ids - // to be able to track asynchronous messages - private val idCount: AtomicInteger = new AtomicInteger(1) - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - @volatile private var isActive = true - private val selectorThread = new Thread("connection-manager-thread") { - override def run(): Unit = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) - // start this thread last, since it invokes run(), which accesses members above - selectorThread.start() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. 
- if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - try { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } catch { - case NonFatal(e) => { - logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in - // triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } catch { - case NonFatal(e) => { - logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) - conn.callOnExceptionCallbacks(e) - } - } - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallbacks(e) - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case NonFatal(e) => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while (isActive) { - while (!registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue() - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue() - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. 
- if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + - connection.getRemoteConnectionManagerId() + "] changed from [" + - intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions - // should be dealt with differently. - case e: CancelledKeyException => - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - 0 - - case e: ClosedSelectorException => - logDebug("Failed select() as selector is closed.", e) - return - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + - " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, - // key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. 
non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId, - securityManager) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - connection match { - case sendingConnection: SendingConnection => - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - connectionsAwaitingSasl -= connection.connectionId - - messageStatuses.synchronized { - messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) - .foreach(status => { - logInfo("Notifying " + status) - status.failWithoutAck() - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case receivingConnection: ReceivingConnection => - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (!sendingConnectionOpt.isDefined) { - logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert(sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values - if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.failWithoutAck() - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - case _ => logError("Unsupported type of connection.") - } - } finally { - // So that the selection keys can be removed. - wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Throwable) { - logInfo("Handling connection error on connection to " + - connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // so that registrations happen ! 
- wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - try { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } catch { - case NonFatal(e) => { - logError("Error when handling messages from " + - connection.getRemoteConnectionManagerId(), e) - connection.callOnExceptionCallbacks(e) - } - } - } - } - handleMessageExecutor.execute(runnable) - /* handleMessage(connection, message) */ - } - - private def handleClientAuthentication( - waitingConn: SendingConnection, - securityMsg: SecurityMessage, - connectionId : ConnectionId) { - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } else { - var replyToken : Array[Byte] = null - try { - replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) - if (waitingConn.isSaslComplete()) { - logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) - connectionsAwaitingSasl -= waitingConn.connectionId - waitingConn.registerAfterAuth() - wakeupSelector() - return - } - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId.toString) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new IOException("Error creating security message") - sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => - logError("Error handling sasl client authentication", e) - waitingConn.close() - throw new IOException("Error evaluating sasl response: ", e) - } - } - } - - private def handleServerAuthentication( - connection: Connection, - securityMsg: SecurityMessage, - connectionId: ConnectionId) { - if (!connection.isSaslComplete()) { - logDebug("saslContext not established") - var replyToken : Array[Byte] = null - try { - connection.synchronized { - if (connection.sparkSaslServer == null) { - logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) - } - } - replyToken = connection.sparkSaslServer.response(securityMsg.getToken) - if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId + - " for: " + connectionId) - } else { - logDebug("Server sasl not completed: " + connection.connectionId + - " for: " + connectionId) - } - if (replyToken != null) { - val securityMsgResp = SecurityMessage.fromResponse(replyToken, - securityMsg.getConnectionId) - val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security Message") - sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } - } catch { - case e: Exception => { - logError("Error in server auth negotiation: " + e) - // It would probably be better to send an error message telling other side auth failed - // but for now just close - connection.close() - } - } - } else { - logDebug("connection already established for this connection id: " + 
connection.connectionId) - } - } - - - private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { - if (bufferMessage.isSecurityNeg) { - logDebug("This is security neg message") - - // parse as SecurityMessage - val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) - - connectionsAwaitingSasl.get(connectionId) match { - case Some(waitingConn) => { - // Client - this must be in response to us doing Send - logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthentication(waitingConn, securityMsg, connectionId) - } - case None => { - // Server - someone sent us something and we haven't authenticated yet - logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthentication(conn, securityMsg, connectionId) - } - } - return true - } else { - if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. - logError("message sent that is not security negotiation message on connection " + - "not authenticated yet, ignoring it!!") - return true - } - } - false - } - - private def handleMessage( - connectionManagerId: ConnectionManagerId, - message: Message, - connection: Connection) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (authEnabled) { - val res = handleAuthentication(connection, bufferMessage) - if (res) { - // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning") - return - } - } - if (bufferMessage.hasAckId()) { - messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status.success(message) - } - case None => { - /** - * We can fall down on this code because of following 2 cases - * - * (1) Invalid ack sent due to buggy code. 
- * - * (2) Late-arriving ack for a SendMessageStatus - * To avoid unwilling late-arriving ack - * caused by long pause like GC, you can set - * larger value than default to spark.core.connection.ack.wait.timeout - */ - logWarning(s"Could not find reference for received ack Message ${message.id}") - } - } - } - } else { - var ackMessage : Option[Message] = None - try { - ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - } catch { - case e: Exception => { - logError(s"Exception was thrown while processing message", e) - ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) - } - } finally { - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { - // see if we need to do sasl before writing - // this should only be the first negotiation as the Client!!! - if (!conn.isSaslComplete()) { - conn.synchronized { - if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) - var firstResponse: Array[Byte] = null - try { - firstResponse = conn.sparkSaslClient.firstToken() - val securityMsg = SecurityMessage.fromResponse(firstResponse, - conn.connectionId.toString()) - val message = securityMsg.toBufferMessage - if (message == null) throw new Exception("Error creating security message") - connectionsAwaitingSasl += ((conn.connectionId, conn)) - sendSecurityMessage(connManagerId, message) - logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId + - " to: " + connManagerId) - } catch { - case e: Exception => { - logError("Error getting first response from the SaslClient.", e) - conn.close() - throw new Exception("Error getting first response from the SaslClient") - } - } - } - } - } else { - logDebug("Sasl already established ") - } - } - - // allow us to add messages to the inbox for doing sasl negotiating - private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, - newConnectionId, securityManager) - logInfo("creating new sending connection for security! " + newConnectionId ) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? - // We did not find it useful in our test-env ... - // If we do re-add it, we should consistently use it everywhere I guess ? 
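`checkSendAuthFirst` and `handleClientAuthentication` above implement the client half of a SASL challenge/response exchange, carried over the `SecurityMessage` round trips set up here. A hedged sketch of that loop using only `javax.security.sasl`; the `sendAndReceive` function is a stand-in for one SecurityMessage round trip and is not part of this code:

```scala
import javax.security.sasl.SaslClient

object SaslHandshake {
  // Drive a SaslClient to completion; `sendAndReceive` sends the client's token and
  // returns the server's next challenge.
  def negotiate(client: SaslClient, sendAndReceive: Array[Byte] => Array[Byte]): Unit = {
    var token: Array[Byte] =
      if (client.hasInitialResponse) client.evaluateChallenge(Array.empty[Byte])
      else Array.empty[Byte]
    while (!client.isComplete) {
      val challenge = sendAndReceive(token)        // blocks on the transport in a real client
      token = client.evaluateChallenge(challenge)  // empty or null once negotiation finishes
    }
  }
}
```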
- message.senderAddress = id.toSocketAddress() - logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") - val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - - // send security message until going connection has been authenticated - connection.send(message) - - wakeupSelector() - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, - connectionManagerId.port) - val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, - newConnectionId, securityManager) - newConnection.onException { - case (conn, e) => { - logError("Exception while sending message.", e) - reportSendingMessageFailure(message.id, e) - } - } - logTrace("creating new sending connection: " + newConnectionId) - registerRequests.enqueue(newConnection) - - newConnection - } - val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - - message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + - "connectionid: " + connection.connectionId) - - if (authEnabled) { - try { - checkSendAuthFirst(connectionManagerId, connection) - } catch { - case NonFatal(e) => { - reportSendingMessageFailure(message.id, e) - } - } - } - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - wakeupSelector() - } - - private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { - // need to tell sender it failed - messageStatuses.synchronized { - val s = messageStatuses.get(messageId) - s match { - case Some(msgStatus) => { - messageStatuses -= messageId - logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.failure(e) - } - case None => { - logError("no messageStatus for failed message id: " + messageId) - } - } - } - } - - private def wakeupSelector() { - selector.wakeup() - } - - /** - * Send a message and block until an acknowledgment is received or an error occurs. - * @param connectionManagerId the message's destination - * @param message the message being sent - * @return a Future that either returns the acknowledgment message or captures an exception. - */ - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Message] = { - val promise = Promise[Message]() - - // It's important that the TimerTask doesn't capture a reference to `message`, which can cause - // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time - // at which they would originally be scheduled to run. Therefore, extract the message id - // from outside of the TimerTask closure (see SPARK-4393 for more context). 
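The comment above, and the code that follows it, describe a leak-avoidance pattern: the timeout task captures only the message id and a weak reference to the promise, never the potentially large message itself. A stripped-down sketch of that idea, with `java.util.Timer` standing in for the wheel timer used in this file and all names illustrative:

```scala
import java.io.IOException
import java.lang.ref.WeakReference
import java.util.{Timer, TimerTask}
import scala.concurrent.Promise

object AckTimeouts {
  private val timer = new Timer("ack-timeout", true)   // daemon thread

  // Schedule a timeout that captures only the message id and a weak ref to the promise.
  def schedule[T](messageId: Int, promise: Promise[T], timeoutMs: Long): TimerTask = {
    val promiseRef = new WeakReference(promise)         // no strong capture of the promise
    val task = new TimerTask {
      override def run(): Unit = {
        val p = promiseRef.get()
        if (p != null) {
          p.tryFailure(new IOException(s"no ack for message $messageId within $timeoutMs ms"))
        }
      }
    }
    timer.schedule(task, timeoutMs)
    task                                                // cancel() this when the ack arrives
  }
}
```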
- val messageId = message.id - // Keep a weak reference to the promise so that the completed promise may be garbage-collected - val promiseReference = new WeakReference(promise) - val timeoutTask: TimerTask = new TimerTask { - override def run(timeout: Timeout): Unit = { - messageStatuses.synchronized { - messageStatuses.remove(messageId).foreach { s => - val e = new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec") - val p = promiseReference.get - if (p != null) { - // Attempt to fail the promise with a Timeout exception - if (!p.tryFailure(e)) { - // If we reach here, then someone else has already signalled success or failure - // on this promise, so log a warning: - logError("Ignore error because promise is completed", e) - } - } else { - // The WeakReference was empty, which should never happen because - // sendMessageReliably's caller should have a strong reference to promise.future; - logError("Promise was garbage collected; this should never happen!", e) - } - } - } - } - } - - val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) - - val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTaskHandle.cancel() - s match { - case scala.util.Failure(e) => - // Indicates a failure where we either never sent or never got ACK'd - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - case scala.util.Success(ackMessage) => - if (ackMessage.hasError) { - val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head - val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) - errorMsgByteBuf.get(errorMsgBytes) - val errorMsg = new String(errorMsgBytes, UTF_8) - val e = new IOException( - s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) - } - } else { - if (!promise.trySuccess(ackMessage)) { - logWarning("Drop ackMessage because promise is completed") - } - } - } - }) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - - sendMessage(connectionManagerId, message) - promise.future - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - isActive = false - ackTimeoutMonitor.stop() - selector.close() - selectorThread.interrupt() - selectorThread.join() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - import scala.concurrent.ExecutionContext.Implicits.global - - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - // scalastyle:off println - println("Received [" + msg + "] from [" + id + "]") - // scalastyle:on println - None - }) - - /* testSequentialSending(manager) */ - /* System.gc() */ - - /* testParallelSending(manager) */ - /* System.gc() */ - - /* testParallelDecreasingSending(manager) */ - /* System.gc() */ - - testContinuousSending(manager) - System.gc() - } - - // scalastyle:off println - def 
testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms " + - "(" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count) { i => - val bufferLen = size * (i + 1) - val bufferContent = Array.tabulate[Byte](bufferLen)(x => x.toByte) - ByteBuffer.allocate(bufferLen).put(bufferContent) - } - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /* println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - f.onFailure { - case e => println("Failed due to " + e) - } - Await.ready(f, 1 second) - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val 
mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } - // scalastyle:on println -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala deleted file mode 100644 index 85d2fe2bf9c20..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.InetSocketAddress -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import com.google.common.base.Charsets.UTF_8 - -import org.apache.spark.util.Utils - -private[nio] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - var isSecurityNeg = false - var hasError = false - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString: String = { - this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" - } -} - - -private[nio] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId(): Int = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - /** - * Create a "negative acknowledgment" to notify a sender that an error occurred - * while processing its message. 
The exception's stacktrace will be formatted - * as a string, serialized into a byte array, and sent as the message payload. - */ - def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { - val exceptionString = Utils.exceptionString(exception) - val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) - val errorMessage = createBufferMessage(serializedExceptionString, ackId) - errorMessage.hasError = true - errorMessage - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.hasError = header.hasError - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala deleted file mode 100644 index 7b3da4bb9d5ee..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.net.{InetAddress, InetSocketAddress} -import java.nio.ByteBuffer - -private[nio] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val hasError: Boolean, - val securityNeg: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). - putInt(securityNeg). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString: String = { - "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg - } - -} - - -private[nio] object MessageChunkHeader { - val HEADER_SIZE = 45 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val hasError = buffer.get() != 0 - val securityNeg = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, - new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala deleted file mode 100644 index b2aec160635c7..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} - -import scala.concurrent.Future - - -/** - * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom - * implementation using Java NIO. - */ -final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) - extends BlockTransferService with Logging { - - private var cm: ConnectionManager = _ - - private var blockDataManager: BlockDataManager = _ - - /** - * Port number the service is listening on, available only after [[init]] is invoked. - */ - override def port: Int = { - checkInit() - cm.id.port - } - - /** - * Host name the service is listening on, available only after [[init]] is invoked. - */ - override def hostName: String = { - checkInit() - cm.id.host - } - - /** - * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. 
- */ - override def init(blockDataManager: BlockDataManager): Unit = { - this.blockDataManager = blockDataManager - cm = new ConnectionManager( - conf.getInt("spark.blockManager.port", 0), - conf, - securityManager, - "Connection manager for block manager") - cm.onReceiveMessage(onBlockMessageReceive) - } - - /** - * Tear down the transfer service. - */ - override def close(): Unit = { - if (cm != null) { - cm.stop() - } - } - - override def fetchBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener): Unit = { - checkInit() - - val cmId = new ConnectionManagerId(host, port) - val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => - BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) - }) - - val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - - // Register the listener on success/failure future callback. - future.onSuccess { case message => - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - - // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. - if (blockMessageArray.isEmpty) { - blockIds.foreach { id => - listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) - } - } else { - for (blockMessage: BlockMessage <- blockMessageArray) { - val msgType = blockMessage.getType - if (msgType != BlockMessage.TYPE_GOT_BLOCK) { - if (blockMessage.getId != null) { - listener.onBlockFetchFailure(blockMessage.getId.toString, - new SparkException(s"Unexpected message $msgType received from $cmId")) - } - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioManagedBuffer(blockMessage.getData)) - } - } - } - }(cm.futureExecContext) - - future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) - } - }(cm.futureExecContext) - } - - /** - * Upload a single block to a remote node, available only after [[init]] is invoked. - * - * This call blocks until the upload completes, or throws an exception upon failures. 
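`fetchBlocks` above turns a single Future-based reply into per-block listener callbacks, failing every requested block when the transport itself fails and failing individual blocks that come back missing (the SPARK-4064 case). A hedged, self-contained sketch of that adaptation; the trait and helper names are made up for illustration:

```scala
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}

// Stand-in for the per-block callbacks of BlockFetchingListener.
trait FetchListener {
  def onBlockFetchSuccess(blockId: String, data: Array[Byte]): Unit
  def onBlockFetchFailure(blockId: String, cause: Throwable): Unit
}

object FetchAdapter {
  // One reply future for the whole request; fan the result out per block.
  def fetchAll(
      blockIds: Seq[String],
      sendRequest: Seq[String] => Future[Map[String, Array[Byte]]],
      listener: FetchListener)(implicit ec: ExecutionContext): Unit = {
    sendRequest(blockIds).onComplete {
      case Success(blocks) =>
        blockIds.foreach { id =>
          blocks.get(id) match {
            case Some(bytes) => listener.onBlockFetchSuccess(id, bytes)
            case None =>
              // empty or partial reply fails just that block
              listener.onBlockFetchFailure(id, new NoSuchElementException(s"no data returned for $id"))
          }
        }
      case Failure(e) =>
        // a transport-level failure fails every block in the request
        blockIds.foreach(id => listener.onBlockFetchFailure(id, e))
    }
  }
}
```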
- */ - override def uploadBlock( - hostname: String, - port: Int, - execId: String, - blockId: BlockId, - blockData: ManagedBuffer, - level: StorageLevel) - : Future[Unit] = { - checkInit() - val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) - val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) - val remoteCmId = new ConnectionManagerId(hostName, port) - val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) - reply.map(x => ())(cm.futureExecContext) - } - - private def checkInit(): Unit = if (cm == null) { - throw new IllegalStateException(getClass.getName + " has not been initialized") - } - - private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => - logError("Exception handling buffer message", e) - Some(Message.createErrorMessage(e, msg.id)) - } - - case otherMessage: Any => - val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" - logError(errorMsg) - Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) - } - } - - private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => - val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + msg + "]") - putBlock(msg.id, msg.data, msg.level) - None - - case BlockMessage.TYPE_GET_BLOCK => - val msg = new GetBlock(blockMessage.getId) - logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) - - case _ => None - } - } - - private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) - logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(blockId: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId) - logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer.nioByteBuffer() - } -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala deleted file mode 100644 index 232c552f9865d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.nio - -import java.nio.ByteBuffer - -import scala.collection.mutable.{ArrayBuffer, StringBuilder} - -import org.apache.spark._ - -/** - * SecurityMessage is class that contains the connectionId and sasl token - * used in SASL negotiation. SecurityMessage has routines for converting - * it to and from a BufferMessage so that it can be sent by the ConnectionManager - * and easily consumed by users when received. - * The api was modeled after BlockMessage. - * - * The connectionId is the connectionId of the client side. Since - * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * - * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side - * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This - * is where the connectionId field is used. node_0 can lookup the connectionId to see if - * it is in response to it being a client or if its in response to someone sending other data. - * - * The format of a SecurityMessage as its sent is: - * - Length of the ConnectionId - * - ConnectionId - * - Length of the token - * - Token - */ -private[nio] class SecurityMessage extends Logging { - - private var connectionId: String = null - private var token: Array[Byte] = null - - def set(byteArr: Array[Byte], newconnectionId: String) { - if (byteArr == null) { - token = new Array[Byte](0) - } else { - token = byteArr - } - connectionId = newconnectionId - } - - /** - * Read the given buffer and set the members of this class. 
- */ - def set(buffer: ByteBuffer) { - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - connectionId = idBuilder.toString() - - val tokenLength = buffer.getInt() - token = new Array[Byte](tokenLength) - if (tokenLength > 0) { - buffer.get(token, 0, tokenLength) - } - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getConnectionId: String = { - return connectionId - } - - def getToken: Array[Byte] = { - return token - } - - /** - * Create a BufferMessage that can be sent by the ConnectionManager containing - * the security information from this class. - * @return BufferMessage - */ - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes - // 4 bytes for the length of token - // token is a byte buffer so just take the length - var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) - buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) - buffer.putInt(token.length) - - if (token.length > 0) { - buffer.put(token) - } - buffer.flip() - buffers += buffer - - var message = Message.createBufferMessage(buffers) - logDebug("message total size is : " + message.size) - message.isSecurityNeg = true - return message - } - - override def toString: String = { - "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" - } -} - -private[nio] object SecurityMessage { - - /** - * Convert the given BufferMessage to a SecurityMessage by parsing the contents - * of the BufferMessage and populating the SecurityMessage fields. - * @param bufferMessage is a BufferMessage that was received - * @return new SecurityMessage - */ - def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(bufferMessage) - newSecurityMessage - } - - /** - * Create a SecurityMessage to send from a given saslResponse. 
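The wire format spelled out in the SecurityMessage scaladoc and implemented by `set(buffer)` / `toBufferMessage` above is a length-prefixed connection id written as chars followed by a length-prefixed token. A minimal round-trip sketch of that framing in isolation, assuming nothing beyond `java.nio.ByteBuffer` (object and method names are illustrative):

```scala
import java.nio.ByteBuffer

object SecurityFraming {
  def encode(connectionId: String, token: Array[Byte]): ByteBuffer = {
    // 4 bytes id length + 2 bytes per id char + 4 bytes token length + token bytes
    val buf = ByteBuffer.allocate(4 + connectionId.length * 2 + 4 + token.length)
    buf.putInt(connectionId.length)
    connectionId.foreach(c => buf.putChar(c))
    buf.putInt(token.length)
    if (token.length > 0) buf.put(token)
    buf.flip()
    buf
  }

  def decode(buf: ByteBuffer): (String, Array[Byte]) = {
    val idLen = buf.getInt()
    val id = new StringBuilder(idLen)
    (0 until idLen).foreach(_ => id += buf.getChar())
    val tokenLen = buf.getInt()
    val token = new Array[Byte](tokenLen)
    if (tokenLen > 0) buf.get(token)
    (id.toString, token)
  }
}
```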
- * @param response is the response to a challenge from the SaslClient or Saslserver - * @param connectionId the client connectionId we are negotiation authentication for - * @return a new SecurityMessage - */ - def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { - val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, connectionId) - newSecurityMessage - } -} diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2e..7515aad09db73 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 1f755db485812..aedced7408cde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -28,12 +28,13 @@ private[spark] class BinaryFileRDD[T]( inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 922030263756b..fc1710fbad0a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -28,7 +28,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) @@ -64,7 +64,7 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds */ private[spark] def removeBlocks() { blockIds.foreach { blockId => - sc.env.blockManager.master.removeBlock(blockId) + sparkContext.env.blockManager.master.removeBlock(blockId) } _isValid = false } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index c1d6971787572..18e8cddbc40db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -27,8 +27,8 @@ import org.apache.spark.util.Utils private[spark] class CartesianPartition( idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], + @transient private val rdd1: RDD[_], + @transient private val rdd2: RDD[_], s1Index: Int, s2Index: Int ) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 72fe215dae73e..b0364623af4cf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -29,7 +29,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition /** * An RDD that recovers checkpointed data from storage. */ -private[spark] abstract class CheckpointRDD[T: ClassTag](@transient sc: SparkContext) +private[spark] abstract class CheckpointRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { // CheckpointRDD should not be checkpointed again diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9c617fc719cb5..935c3babd8ea1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -22,10 +22,11 @@ import scala.language.existentials import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -74,7 +75,9 @@ private[spark] class CoGroupPartition( * @param part partitioner used to partition the shuffle output */ @DeveloperApi -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) +class CoGroupedRDD[K: ClassTag]( + @transient var rdds: Seq[RDD[_ <: Product2[K, _]]], + part: Partitioner) extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) { // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs). @@ -125,8 +128,6 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -147,34 +148,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new 
CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1f8719eead02..77b57132b9f1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -99,7 +99,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp */ @DeveloperApi class HadoopRDD[K, V]( - @transient sc: SparkContext, + sc: SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -109,7 +109,7 @@ class HadoopRDD[K, V]( extends RDD[(K, V)](sc, Nil) with Logging { if (initLocalJobConfFuncOpt.isDefined) { - sc.clean(initLocalJobConfFuncOpt.get) + sparkContext.clean(initLocalJobConfFuncOpt.get) } def this( @@ -137,7 +137,7 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() - private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. 
val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala index daa5779d688cc..bfe19195fcd37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalCheckpointRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.storage.RDDBlockId * @param numPartitions the number of partitions in the checkpointed RDD */ private[spark] class LocalCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, rddId: Int, numPartitions: Int) extends CheckpointRDD[T](sc) { diff --git a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala index d6fad896845f6..c115e0ff74d3c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/LocalRDDCheckpointData.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * is written to the local, ephemeral block storage that lives in each executor. This is useful * for use cases where RDDs build up long lineages that need to be truncated often (e.g. GraphX). */ -private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class LocalRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala index 1f2213d0c4346..417ff5278db2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -41,7 +41,7 @@ private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M // In certain join operations, prepare can be called on the same partition multiple times. // In this case, we need to ensure that each call to compute gets a separate prepare argument. - private[this] var preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] + private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] /** * Prepare a partition for a single call to compute. 
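Several of the RDD changes in this patch replace `@transient foo` constructor parameters with either plain parameters or `@transient private val foo`. The reason is that the annotation only does anything when the parameter is actually a field: on a plain constructor parameter there is no field to mark, whereas on a `val` the generated field is excluded from Java serialization. A small sketch illustrating the difference under that assumption (the class and field names are made up):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// `@transient` on a val marks the backing field transient, so it is skipped when the
// object is serialized; on a non-val parameter there is no backing field at all.
class Holder(@transient private val scratch: Array[Byte], val kept: Int) extends Serializable

object TransientDemo extends App {
  val bytes = new ByteArrayOutputStream()
  new ObjectOutputStream(bytes).writeObject(new Holder(new Array[Byte](1 << 20), 42))
  // The 1 MB scratch buffer is not written; only `kept` (plus object headers) is.
  println(s"serialized size: ${bytes.size()} bytes")
}
```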
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 6a9c004d65cff..2872b93b8730e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -40,11 +40,10 @@ import org.apache.spark.storage.StorageLevel private[spark] class NewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -68,14 +67,14 @@ class NewHadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient private val _conf: Configuration) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) - // private val serializableConf = new SerializableWritable(conf) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) + // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -84,14 +83,35 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. 
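The `getConf` comment above, together with the `CONFIGURATION_INSTANTIATION_LOCK` block that follows, boils down to: hand each task its own clone of the broadcast Configuration, but only under a global lock and only when the user opts in. A stripped-down sketch of that guard, assuming only Hadoop's `Configuration` copy constructor (names are illustrative):

```scala
import org.apache.hadoop.conf.Configuration

object ConfCloning {
  // Configuration's constructor is not thread-safe (HADOOP-10456), so all clones go
  // through one global lock; cloning itself is opt-in because it is expensive.
  private val constructorLock = new Object()

  def confFor(shared: Configuration, cloneConf: Boolean): Configuration = {
    if (cloneConf) {
      constructorLock.synchronized { new Configuration(shared) }  // private per-task copy
    } else {
      shared  // cheap, but callers must treat it as read-only
    }
  }
}
```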
+ NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { case configurable: Configurable => - configurable.setConf(conf) + configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = newJobContext(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -262,12 +286,13 @@ private[spark] class WholeTextFileRDD( inputFormatClass: Class[_ <: WholeTextFileInputFormat], keyClass: Class[String], valueClass: Class[String], - @transient conf: Configuration, + conf: Configuration, minPartitions: Int) extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => configurable.setConf(conf) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 4e5f2e8a5d467..a981b63942e6d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -57,7 +57,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) with SparkHadoopMapReduceUtil with Serializable { + /** + * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C * Note that V and C can be different -- for example, one might group an RDD of type @@ -70,12 +72,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). 
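The scaladoc just above describes `combineByKey`'s three functions in the abstract; a short usage sketch makes their roles concrete. This assumes a spark-shell session where `sc` is the SparkContext and the input data is made up:

```scala
val scores = sc.parallelize(Seq(("a", 1.0), ("b", 4.0), ("a", 3.0)))

val meanByKey = scores.combineByKey(
    (v: Double) => (v, 1),                                             // createCombiner
    (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),       // mergeValue
    (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2) // mergeCombiners
  ).mapValues { case (sum, count) => sum / count }

// meanByKey.collect()  =>  Array((a,2.0), (b,4.0))   (order not guaranteed)
```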
*/ - def combineByKey[C](createCombiner: V => C, + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializer: Serializer = null): RDD[(K, C)] = self.withScope { + serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (keyClass.isArray) { if (mapSideCombine) { @@ -103,13 +107,50 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the output RDD. + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] */ - def combineByKey[C](createCombiner: V => C, + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true, + serializer: Serializer = null): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + partitioner, mapSideCombine, serializer)(null) + } + + /** + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + * This method is here for backward compatibility. It does not provide combiner + * classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, numPartitions: Int): RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, numPartitions)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the output RDD. + */ + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + numPartitions: Int)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, + new HashPartitioner(numPartitions)) } /** @@ -133,7 +174,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // We will clean the combiner closure later in `combineByKey` val cleanedSeqOp = self.context.clean(seqOp) - combineByKey[U]((v: V) => cleanedSeqOp(createZero(), v), cleanedSeqOp, combOp, partitioner) + combineByKeyWithClassTag[U]((v: V) => cleanedSeqOp(createZero(), v), + cleanedSeqOp, combOp, partitioner) } /** @@ -182,7 +224,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) val cleanedFunc = self.context.clean(func) - combineByKey[V]((v: V) => cleanedFunc(createZero(), v), cleanedFunc, cleanedFunc, partitioner) + combineByKeyWithClassTag[V]((v: V) => cleanedFunc(createZero(), v), + cleanedFunc, cleanedFunc, partitioner) } /** @@ -268,7 +311,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * "combiner" in MapReduce. 
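`reduceByKey`, documented above, is the most common specialization of this machinery: the same merge function is used on both the map and reduce side, so each partition pre-aggregates before shuffling, much like a MapReduce combiner. A tiny usage sketch, again assuming a spark-shell session with `sc` and made-up data:

```scala
val words = sc.parallelize(Seq("a", "b", "a", "a", "b"))
val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
// counts.collect()  =>  Array((a,3), (b,2))   (order not guaranteed)
```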
*/ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope { - combineByKey[V]((v: V) => v, func, func, partitioner) + combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner) } /** @@ -392,7 +435,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) h1 } - combineByKey(createHLL, mergeValueHLL, mergeHLL, partitioner).mapValues(_.cardinality()) + combineByKeyWithClassTag(createHLL, mergeValueHLL, mergeHLL, partitioner) + .mapValues(_.cardinality()) } /** @@ -466,7 +510,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val createCombiner = (v: V) => CompactBuffer(v) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 - val bufs = combineByKey[CompactBuffer[V]]( + val bufs = combineByKeyWithClassTag[CompactBuffer[V]]( createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -565,12 +609,30 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. This method is here for backward compatibility. It + * does not provide combiner classtag information to the shuffle. + * + * @see [[combineByKeyWithClassTag]] + */ + def combineByKey[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners)(null) + } + + /** + * :: Experimental :: + * Simplified version of combineByKeyWithClassTag that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = self.withScope { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) + @Experimental + def combineByKeyWithClassTag[C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope { + combineByKeyWithClassTag(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } /** @@ -934,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -955,6 +1018,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. 
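A usage sketch of the new experimental overload that the combineByKey family above now delegates to (local-mode driver code; the app name, master, and data are illustrative). The ClassTag for the combiner type is supplied implicitly, whereas the old combineByKey overloads keep compiling and simply pass a null ClassTag:

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("cbk-sketch").setMaster("local[2]"))
    val scores = sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("b", 2.0)))

    // Per-key average: combiner type C = (sum, count), ClassTag[(Double, Int)] is implicit.
    val avgByKey = scores.combineByKeyWithClassTag(
        (v: Double) => (v, 1),
        (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),
        (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2))
      .mapValues { case (sum, count) => sum / count }

    avgByKey.collect()  // contains ("a", 2.0) and ("b", 2.0)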
*/ def saveAsHadoopFile( path: String, @@ -967,10 +1035,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -984,6 +1049,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -994,6 +1072,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1002,7 +1085,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1051,6 +1135,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. 
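The guard just added is easiest to read in isolation; a hedged restatement with an illustrative helper name (not part of the patch):

    // Warn when speculative execution is combined with a committer whose class name
    // suggests it writes directly to the final location, since speculative duplicates
    // can then lose or corrupt committed output (SPARK-10063).
    def directCommitterWarning(
        speculationEnabled: Boolean,
        committerClassName: String): Option[String] = {
      if (speculationEnabled && committerClassName.contains("Direct")) {
        Some(s"$committerClassName may write data directly to the final location; with " +
          "speculation enabled this can cause data loss (see SPARK-10063). Prefer an " +
          "output committer such as FileOutputCommitter.")
      } else {
        None
      }
    }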
+ val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1065,7 +1163,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1093,7 +1190,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e2394e28f8d26..582fa93afe34e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -83,8 +83,8 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } private[spark] class ParallelCollectionRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient data: Seq[T], + sc: SparkContext, + @transient private val data: Seq[T], numSlices: Int, locationPrefs: Map[Int, Seq[String]]) extends RDD[T](sc, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index a00f4c1cdff91..d6a37e8cc5dac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -32,7 +32,7 @@ private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Par * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. 
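For context on PartitionPruningRDD, a small driver-side sketch of how this developer API is typically used to drop partitions without a shuffle (assumes an active SparkContext named `sc` and the companion `create` factory, which supplies the ClassTag):

    import org.apache.spark.rdd.PartitionPruningRDD

    // Assumes an active SparkContext named `sc`, as in the earlier sketch.
    val data = sc.parallelize(1 to 100, 10)

    // Keep only the even-numbered parent partitions; no data moves, the child simply
    // exposes a subset of the parent's partitions through PruneDependency.
    val pruned = PartitionPruningRDD.create(data, partitionIndex => partitionIndex % 2 == 0)
    pruned.partitions.length  // 5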
*/ -private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -55,8 +55,8 @@ private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterF */ @DeveloperApi class PartitionPruningRDD[T: ClassTag]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) + prev: RDD[T], + partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index a637d6f15b7e5..3b1acacf409b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -47,8 +47,8 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient preservesPartitioning: Boolean, - @transient seed: Long = Utils.random.nextLong) + preservesPartitioning: Boolean, + @transient private val seed: Long = Utils.random.nextLong) extends RDD[U](prev) { @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 081c721f23687..a97bb174438a5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -294,7 +294,11 @@ abstract class RDD[T: ClassTag]( */ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) + if (isCheckpointedAndMaterialized) { + firstParent[T].iterator(split, context) + } else { + compute(split, context) + } } /** @@ -469,50 +473,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * 
math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** @@ -1526,20 +1524,37 @@ abstract class RDD[T: ClassTag]( persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true) } - checkpointData match { - case Some(reliable: ReliableRDDCheckpointData[_]) => logWarning( - "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") - case _ => + // If this RDD is already checkpointed and materialized, its lineage is already truncated. + // We must not override our `checkpointData` in this case because it is needed to recover + // the checkpointed data. If it is overridden, next time materializing on this RDD will + // cause error. + if (isCheckpointedAndMaterialized) { + logWarning("Not marking RDD for local checkpoint because it was already " + + "checkpointed and materialized") + } else { + // Lineage is not truncated yet, so just override any existing checkpoint data with ours + checkpointData match { + case Some(_: ReliableRDDCheckpointData[_]) => logWarning( + "RDD was already marked for reliable checkpointing: overriding with local checkpoint.") + case _ => + } + checkpointData = Some(new LocalRDDCheckpointData(this)) } - checkpointData = Some(new LocalRDDCheckpointData(this)) this } /** - * Return whether this RDD is marked for checkpointing, either reliably or locally. + * Return whether this RDD is checkpointed and materialized, either reliably or locally. */ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) + /** + * Return whether this RDD is checkpointed and materialized, either reliably or locally. + * This is introduced as an alias for `isCheckpointed` to clarify the semantics of the + * return value. Exposed for testing. + */ + private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed + /** * Return whether this RDD is marked for local checkpointing. * Exposed for testing. 
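A usage sketch of the clarified semantics (assumes an active SparkContext named `sc`): `isCheckpointed` now explicitly means checkpointed and materialized, so it is only expected to flip to true after an action has computed the RDD.

    // Assumes an active SparkContext named `sc`.
    val nums = sc.parallelize(1 to 1000, 4)

    nums.localCheckpoint()              // only marks the RDD; lineage is still intact
    println(nums.isCheckpointed)        // false: nothing has been materialized yet

    nums.count()                        // first action materializes the checkpoint
    println(nums.isCheckpointed)        // true once checkpointed and materialized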
@@ -1666,7 +1681,7 @@ abstract class RDD[T: ClassTag]( import Utils.bytesToString val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" - val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + val storageInfo = rdd.context.getRDDStorageInfo(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; ExternalBlockStoreSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), bytesToString(info.externalBlockStoreSize), bytesToString(info.diskSize))) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 0e43520870c0a..429514b4f6bee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -36,7 +36,7 @@ private[spark] object CheckpointState extends Enumeration { * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ -private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 44667281c1063..540cbd688b63b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.{JsonIgnore, JsonInclude, JsonPropertyOr import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.google.common.base.Objects import org.apache.spark.{Logging, SparkContext} @@ -67,6 +68,8 @@ private[spark] class RDDOperationScope( } } + override def hashCode(): Int = Objects.hashCode(id, name, parent) + override def toString: String = toJson } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index 35d8b0bfd18c5..a69be6a068bbf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} * An RDD that reads from checkpoint files previously written to reliable storage. 
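The hashCode added to RDDOperationScope above keeps it consistent with its field-based equals; a minimal illustration of the same Guava pattern on a hypothetical class:

    import com.google.common.base.Objects

    // Hypothetical class: hash exactly the fields that equals compares, so equal
    // scopes always land in the same hash bucket.
    class Scope(val id: String, val name: String, val parent: Option[Scope]) {
      override def equals(other: Any): Boolean = other match {
        case s: Scope => id == s.id && name == s.name && parent == s.parent
        case _ => false
      }

      override def hashCode(): Int = Objects.hashCode(id, name, parent)
    }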
*/ private[spark] class ReliableCheckpointRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val checkpointPath: String) extends CheckpointRDD[T](sc) { @@ -144,7 +144,9 @@ private[spark] object ReliableCheckpointRDD extends Logging { } else { // Some other copy of this task must've finished before us and renamed it logInfo(s"Final output path $finalOutputPath already exists; not overwriting it") - fs.delete(tempOutputPath, false) + if (!fs.delete(tempOutputPath, false)) { + logWarning(s"Error deleting ${tempOutputPath}") + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 1df8eef5ff2b9..91cad6662e4d2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.SerializableConfiguration * An implementation of checkpointing that writes the RDD data to reliable storage. * This allows drivers to be restarted on failure with previously computed state. */ -private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) +private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { // The directory to which the associated RDD has been checkpointed to @@ -89,7 +89,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient rdd: RDD[ } -private[spark] object ReliableRDDCheckpointData { +private[spark] object ReliableRDDCheckpointData extends Logging { /** Return the path of the directory to which this RDD's checkpoint data is written. */ def checkpointPath(sc: SparkContext, rddId: Int): Option[Path] = { @@ -101,7 +101,9 @@ private[spark] object ReliableRDDCheckpointData { checkpointPath(sc, rddId).foreach { path => val fs = path.getFileSystem(sc.hadoopConfiguration) if (fs.exists(path)) { - fs.delete(path, true) + if (!fs.delete(path, true)) { + logWarning(s"Error deleting ${path.toString()}") + } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 2dc47f95937cb..a013c3f66a3a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer @@ -37,7 +39,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { */ // TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs @DeveloperApi -class ShuffledRDD[K, V, C]( +class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( @transient var prev: RDD[_ <: Product2[K, V]], part: Partitioner) extends RDD[(K, C)](prev.context, Nil) { @@ -84,6 +86,12 @@ class ShuffledRDD[K, V, C]( Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) } + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + tracker.getPreferredLocationsForShuffle(dep, partition.index) + } + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { val 
dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index fa3fecc80cb63..0228c54e0511c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Ut private[spark] class SqlNewHadoopPartition( rddId: Int, val index: Int, - @transient rawSplit: InputSplit with Writable) + rawSplit: InputSplit with Writable) extends SparkPartition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -61,9 +61,9 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. */ private[spark] class SqlNewHadoopRDD[V: ClassTag]( - @transient sc : SparkContext, + sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], + @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) @@ -86,7 +86,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a4fa301b06e3..25ec685eff5ab 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -63,15 +63,17 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => + def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]]) + : Dependency[_] = { if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializer) + new ShuffleDependency[T1, T2, Any](rdd, part, serializer) } } + Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2)) } override def getPartitions: Array[Partition] = { @@ -105,7 +107,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + def integrate(depNum: Int, op: Product2[K, V] => Unit): Unit = { dependencies(depNum) match { case oneToOneDependency: OneToOneDependency[_] => val dependencyPartition = partition.narrowDeps(depNum).get.split diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 3986645350a82..66cf4369da2ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class UnionPartition[T: ClassTag]( idx: Int, - @transient rdd: RDD[T], + @transient private val rdd: RDD[T], val parentRddIndex: Int, - @transient parentRddPartitionIndex: Int) + @transient private val 
parentRddPartitionIndex: Int) extends Partition { var parentPartition: Partition = rdd.partitions(parentRddPartitionIndex) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index b3c64394abc76..70bf04de6400d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]], + @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) extends Partition { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e277ae28d588f..32931d59acb18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -37,7 +37,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) * @tparam T parent RDD item type */ private[spark] -class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { +class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) { /** The start index of each partition. */ @transient private val startIndices: Array[Long] = { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala new file mode 100644 index 0000000000000..eb0b26947f504 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcAddress.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.util.Utils + + +/** + * Address for an RPC environment, with hostname and port. + */ +private[spark] case class RpcAddress(host: String, port: Int) { + + def hostPort: String = host + ":" + port + + /** Returns a string in the form of "spark://host:port". */ + def toSparkURL: String = "spark://" + hostPort + + override def toString: String = hostPort +} + + +private[spark] object RpcAddress { + + /** Return the [[RpcAddress]] represented by `uri`. 
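A quick sketch of the relocated RpcAddress from the caller's point of view (this is internal private[spark] API; the host and port values are illustrative):

    import org.apache.spark.rpc.RpcAddress

    val addr = RpcAddress("worker-1.example.com", 7077)
    addr.hostPort      // "worker-1.example.com:7077"
    addr.toSparkURL    // "spark://worker-1.example.com:7077"

    // Both companion helpers parse back to an equal address.
    RpcAddress.fromSparkURL("spark://worker-1.example.com:7077")    // == addr
    RpcAddress.fromURIString("spark://worker-1.example.com:7077")   // == addr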
*/ + def fromURIString(uri: String): RpcAddress = { + val uriObj = new java.net.URI(uri) + RpcAddress(uriObj.getHost, uriObj.getPort) + } + + /** Returns the [[RpcAddress]] encoded in the form of "spark://host:port" */ + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 3e5b64265e919..f527ec86ab7b2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -37,5 +37,5 @@ private[spark] trait RpcCallContext { /** * The sender of this message. */ - def sender: RpcEndpointRef + def senderAddress: RpcAddress } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index dfcbc51cdf616..0ba95169529e6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -28,20 +28,6 @@ private[spark] trait RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv } -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint - - /** * An end point for the RPC that defines what functions to trigger given a message. * @@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint { } /** - * Invoked before [[RpcEndpoint]] starts to handle any message. + * Invoked when `remoteAddress` is connected to the current node. */ - def onStart(): Unit = { + def onConnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when [[RpcEndpoint]] is stopping. + * Invoked when `remoteAddress` is lost. */ - def onStop(): Unit = { + def onDisconnected(remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is connected to the current node. + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. */ - def onConnected(remoteAddress: RpcAddress): Unit = { + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { // By default, do nothing. } /** - * Invoked when `remoteAddress` is lost. + * Invoked before [[RpcEndpoint]] starts to handle any message. */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { + def onStart(): Unit = { // By default, do nothing. } /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. + * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot + * use it to send or ask messages. */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + def onStop(): Unit = { // By default, do nothing. 
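A minimal endpoint sketch against this internal API, showing the callbacks the reordered trait documents (EchoEndpoint and its log strings are illustrative, not part of the patch):

    import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
      override def onStart(): Unit = println("endpoint started")

      // Replies go back through the context; the sender is now exposed as an RpcAddress.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(s"echo to ${context.senderAddress}: $msg")
      }

      override def onDisconnected(remoteAddress: RpcAddress): Unit =
        println(s"lost connection to $remoteAddress")

      override def onStop(): Unit = println("endpoint stopped")  // `self` is null here
    }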
} @@ -146,3 +133,16 @@ private[spark] trait RpcEndpoint { } } } + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala new file mode 100644 index 0000000000000..d177881fb3053 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +private[rpc] class RpcEndpointNotFoundException(uri: String) + extends SparkException(s"Cannot find endpoint: $uri") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 7409ac8859991..f25710bb5bd6e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. 
*/ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) +private[spark] abstract class RpcEndpointRef(conf: SparkConf) extends Serializable with Logging { private[this] val maxRetries = RpcUtils.numRetries(conf) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 29debe8081308..2c4a8b9a0a878 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,12 +17,7 @@ package org.apache.spark.rpc -import java.net.URI -import java.util.concurrent.TimeoutException - -import scala.concurrent.{Awaitable, Await, Future} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} @@ -35,9 +30,10 @@ import org.apache.spark.util.{RpcUtils, Utils} private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvNames = Map( + "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", + "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "netty") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] } @@ -52,7 +48,6 @@ private[spark] object RpcEnv { val config = RpcEnvConfig(conf, name, host, port, securityManager) getRpcEnvFactory(conf).create(config) } - } @@ -98,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri)) } - /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` - * asynchronously. - */ - def asyncSetupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) - } - /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. * This is a blocking action. @@ -154,144 +140,3 @@ private[spark] case class RpcEnvConfig( host: String, port: Int, securityManager: SecurityManager) - - -/** - * Represents a host and port. - */ -private[spark] case class RpcAddress(host: String, port: Int) { - // TODO do we need to add the type of RpcEnv in the address? - - val hostPort: String = host + ":" + port - - override val toString: String = hostPort - - def toSparkURL: String = "spark://" + hostPort -} - - -private[spark] object RpcAddress { - - /** - * Return the [[RpcAddress]] represented by `uri`. - */ - def fromURI(uri: URI): RpcAddress = { - RpcAddress(uri.getHost, uri.getPort) - } - - /** - * Return the [[RpcAddress]] represented by `uri`. - */ - def fromURIString(uri: String): RpcAddress = { - fromURI(new java.net.URI(uri)) - } - - def fromSparkURL(sparkUrl: String): RpcAddress = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - RpcAddress(host, port) - } -} - - -/** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. 
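Stepping back to the RpcEnv factory change above, its net effect is easiest to see as the lookup it performs; a condensed restatement (the standalone SparkConf is for illustration only):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    val rpcEnvNames = Map(
      "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
      "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")

    // "netty" is now the default; "akka" stays selectable via spark.rpc, and any other
    // value is treated as a fully-qualified RpcEnvFactory class name.
    val chosen = conf.get("spark.rpc", "netty")
    val factoryClassName = rpcEnvNames.getOrElse(chosen.toLowerCase, chosen)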
- */ -private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) - extends TimeoutException(message) { initCause(cause) } - - -/** - * Associates a timeout with a description so that a when a TimeoutException occurs, additional - * context about the timeout can be amended to the exception message. - * @param duration timeout duration in seconds - * @param timeoutProp the configuration property that controls this timeout - */ -private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) - extends Serializable { - - /** Amends the standard message of TimeoutException to include the description */ - private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { - new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te) - } - - /** - * PartialFunction to match a TimeoutException and add the timeout description to the message - * - * @note This can be used in the recover callback of a Future to add to a TimeoutException - * Example: - * val timeout = new RpcTimeout(5 millis, "short timeout") - * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) - */ - def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { - // The exception has already been converted to a RpcTimeoutException so just raise it - case rte: RpcTimeoutException => throw rte - // Any other TimeoutException get converted to a RpcTimeoutException with modified message - case te: TimeoutException => throw createRpcTimeoutException(te) - } - - /** - * Wait for the completed result and return it. If the result is not available within this - * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. - * @param awaitable the `Awaitable` to be awaited - * @throws RpcTimeoutException if after waiting for the specified time `awaitable` - * is still not ready - */ - def awaitResult[T](awaitable: Awaitable[T]): T = { - try { - Await.result(awaitable, duration) - } catch addMessageIfTimeout - } -} - - -private[spark] object RpcTimeout { - - /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @throws NoSuchElementException if property is not set - */ - def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds } - new RpcTimeout(timeout, timeoutProp) - } - - /** - * Lookup the timeout property in the configuration and create - * a RpcTimeout with the property key in the description. - * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutProp property key for the timeout in seconds - * @param defaultValue default timeout value in seconds if property not found - */ - def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { - val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds } - new RpcTimeout(timeout, timeoutProp) - } - - /** - * Lookup prioritized list of timeout properties in the configuration - * and create a RpcTimeout with the first set property key in the - * description. 
- * Uses the given default value if property is not set - * @param conf configuration properties containing the timeout - * @param timeoutPropList prioritized list of property keys for the timeout in seconds - * @param defaultValue default timeout value in seconds if no properties found - */ - def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { - require(timeoutPropList.nonEmpty) - - // Find the first set property or use the default value with the first property - val itr = timeoutPropList.iterator - var foundProp: Option[(String, String)] = None - while (itr.hasNext && foundProp.isEmpty){ - val propKey = itr.next() - conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } - } - val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) - val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds } - new RpcTimeout(timeout, finalProp._1) - } -} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala new file mode 100644 index 0000000000000..285786ebf9f1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.TimeoutException + +import scala.concurrent.{Awaitable, Await} +import scala.concurrent.duration._ + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + + +/** + * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + */ +private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) + extends TimeoutException(message) { initCause(cause) } + + +/** + * Associates a timeout with a description so that a when a TimeoutException occurs, additional + * context about the timeout can be amended to the exception message. + * + * @param duration timeout duration in seconds + * @param timeoutProp the configuration property that controls this timeout + */ +private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String) + extends Serializable { + + /** Amends the standard message of TimeoutException to include the description */ + private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = { + new RpcTimeoutException(te.getMessage + ". 
This timeout is controlled by " + timeoutProp, te) + } + + /** + * PartialFunction to match a TimeoutException and add the timeout description to the message + * + * @note This can be used in the recover callback of a Future to add to a TimeoutException + * Example: + * val timeout = new RpcTimeout(5 millis, "short timeout") + * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) + */ + def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { + // The exception has already been converted to a RpcTimeoutException so just raise it + case rte: RpcTimeoutException => throw rte + // Any other TimeoutException get converted to a RpcTimeoutException with modified message + case te: TimeoutException => throw createRpcTimeoutException(te) + } + + /** + * Wait for the completed result and return it. If the result is not available within this + * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. + * @param awaitable the `Awaitable` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * is still not ready + */ + def awaitResult[T](awaitable: Awaitable[T]): T = { + try { + Await.result(awaitable, duration) + } catch addMessageIfTimeout + } +} + + +private[spark] object RpcTimeout { + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @throws NoSuchElementException if property is not set + */ + def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup the timeout property in the configuration and create + * a RpcTimeout with the property key in the description. + * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutProp property key for the timeout in seconds + * @param defaultValue default timeout value in seconds if property not found + */ + def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = { + val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue).seconds } + new RpcTimeout(timeout, timeoutProp) + } + + /** + * Lookup prioritized list of timeout properties in the configuration + * and create a RpcTimeout with the first set property key in the + * description. 
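A usage sketch of the relocated RpcTimeout (internal API; the property key and default below are just examples): the timeout is resolved from SparkConf, and awaitResult/addMessageIfTimeout annotate any TimeoutException with the controlling property name.

    import scala.concurrent.Promise
    import org.apache.spark.SparkConf
    import org.apache.spark.rpc.RpcTimeout

    val conf = new SparkConf()
    val askTimeout = RpcTimeout(conf, "spark.rpc.askTimeout", "120s")

    val neverCompleted = Promise[String]().future
    // Would block for 120 seconds and then throw an RpcTimeoutException whose message
    // names "spark.rpc.askTimeout", so it is left commented out here:
    // askTimeout.awaitResult(neverCompleted)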
+ * Uses the given default value if property is not set + * @param conf configuration properties containing the timeout + * @param timeoutPropList prioritized list of property keys for the timeout in seconds + * @param defaultValue default timeout value in seconds if no properties found + */ + def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = { + require(timeoutPropList.nonEmpty) + + // Find the first set property or use the default value with the first property + val itr = timeoutPropList.iterator + var foundProp: Option[(String, String)] = None + while (itr.hasNext && foundProp.isEmpty){ + val propKey = itr.next() + conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) } + } + val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue) + val timeout = { Utils.timeStringAsSeconds(finalProp._2).seconds } + new RpcTimeout(timeout, finalProp._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index fc17542abf81d..3fad595a0d0b0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -39,10 +39,6 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and * remove Akka from the dependencies. - * - * @param actorSystem - * @param conf - * @param boundPort */ private[spark] class AkkaRpcEnv private[akka] ( val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) @@ -87,9 +83,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use lazy because the Actor needs to use `endpointRef`. + // Use defered function because the Actor needs to use `endpointRef`. // So `actorRef` should be created after assigning `endpointRef`. - lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) @@ -166,9 +162,9 @@ private[spark] class AkkaRpcEnv private[akka] ( _sender ! 
AkkaMessage(response, false) } - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + // Use "lazy" because most of RpcEndpoints don't need "senderAddress" + override lazy val senderAddress: RpcAddress = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address }) } else { endpoint.receive @@ -272,13 +268,20 @@ private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging } private[akka] class AkkaRpcEndpointRef( - @transient defaultAddress: RpcAddress, - @transient _actorRef: => ActorRef, - @transient conf: SparkConf, - @transient initInConstructor: Boolean = true) + @transient private val defaultAddress: RpcAddress, + @transient private val _actorRef: () => ActorRef, + conf: SparkConf, + initInConstructor: Boolean) extends RpcEndpointRef(conf) with Logging { - lazy val actorRef = _actorRef + def this( + defaultAddress: RpcAddress, + _actorRef: ActorRef, + conf: SparkConf) = { + this(defaultAddress, () => _actorRef, conf, true) + } + + lazy val actorRef = _actorRef() override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala new file mode 100644 index 0000000000000..7bf44a6565b61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc._ +import org.apache.spark.util.ThreadUtils + +/** + * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). + */ +private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { + + private class EndpointData( + val name: String, + val endpoint: RpcEndpoint, + val ref: NettyRpcEndpointRef) { + val inbox = new Inbox(ref, endpoint) + } + + private val endpoints = new ConcurrentHashMap[String, EndpointData] + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] + + // Track the receivers whose inboxes may contain messages. + private val receivers = new LinkedBlockingQueue[EndpointData] + + /** + * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced + * immediately. 
+ */ + @GuardedBy("this") + private var stopped = false + + def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { + val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) + synchronized { + if (stopped) { + throw new IllegalStateException("RpcEnv has been stopped") + } + if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { + throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + } + val data = endpoints.get(name) + endpointRefs.put(data.endpoint, data.ref) + receivers.offer(data) // for the OnStart message + } + endpointRef + } + + def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint) + + def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint) + + // Should be idempotent + private def unregisterRpcEndpoint(name: String): Unit = { + val data = endpoints.remove(name) + if (data != null) { + data.inbox.stop() + receivers.offer(data) // for the OnStop message + } + // Don't clean `endpointRefs` here because it's possible that some messages are being processed + // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via + // `removeRpcEndpointRef`. + } + + def stop(rpcEndpointRef: RpcEndpointRef): Unit = { + synchronized { + if (stopped) { + // This endpoint will be stopped by Dispatcher.stop() method. + return + } + unregisterRpcEndpoint(rpcEndpointRef.name) + } + } + + /** + * Send a message to all registered [[RpcEndpoint]]s in this process. + * + * This can be used to make network events known to all end points (e.g. "a new node connected"). + */ + def postToAll(message: InboxMessage): Unit = { + val iter = endpoints.keySet().iterator() + while (iter.hasNext) { + val name = iter.next + postMessage( + name, + _ => message, + () => { logWarning(s"Drop $message because $name has been stopped") }) + } + } + + /** Posts a message sent by a remote endpoint. */ + def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { + val rpcCallContext = + new RemoteNettyRpcCallContext( + nettyEnv, sender, callback, message.senderAddress, message.needReply) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) + } + + def onEndpointStopped(): Unit = { + callback.onFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + postMessage(message.receiver.name, createMessage, onEndpointStopped) + } + + /** Posts a message sent by a local endpoint. */ + def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { + val rpcCallContext = + new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) + } + + def onEndpointStopped(): Unit = { + p.tryFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + postMessage(message.receiver.name, createMessage, onEndpointStopped) + } + + /** + * Posts a message to a specific endpoint. + * + * @param endpointName name of the endpoint. + * @param createMessageFn function to create the message. 
+ * @param callbackIfStopped callback function if the endpoint is stopped. + */ + private def postMessage( + endpointName: String, + createMessageFn: NettyRpcEndpointRef => InboxMessage, + callbackIfStopped: () => Unit): Unit = { + val shouldCallOnStop = synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(createMessageFn(data.ref)) + receivers.offer(data) + false + } + } + if (shouldCallOnStop) { + // We don't need to call `onStop` in the `synchronized` block + callbackIfStopped() + } + } + + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + } + // Stop all endpoints. This will queue all endpoints for processing by the message loops. + endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) + // Enqueue a message that tells the message loops to stop. + receivers.offer(PoisonPill) + threadpool.shutdown() + } + + def awaitTermination(): Unit = { + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + /** + * Return if the endpoint exists + */ + def verify(name: String): Boolean = { + endpoints.containsKey(name) + } + + /** Thread pool used for dispatching messages. */ + private val threadpool: ThreadPoolExecutor = { + val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads", + Runtime.getRuntime.availableProcessors()) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = receivers.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. + receivers.offer(PoisonPill) + return + } + data.inbox.process(Dispatcher.this) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new EndpointData(null, null, null) +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala new file mode 100644 index 0000000000000..c72b588db57fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
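The Dispatcher above is exercised indirectly through RpcEnv. A minimal sketch of the register/post flow, assuming the `Greeter` endpoint, the port, and the `spark.rpc=netty` setting are illustrative choices rather than part of this patch:

{{{
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Hypothetical endpoint; replies are produced inside a "dispatcher-event-loop" thread.
class Greeter(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case name: String => context.reply(s"hello, $name")
  }
}

val conf = new SparkConf().set("spark.rpc", "netty")      // assumed short name for NettyRpcEnvFactory
val env = RpcEnv.create("demo", "localhost", 52345, conf, new SecurityManager(conf))
val ref = env.setupEndpoint("greeter", new Greeter(env))  // Dispatcher.registerRpcEndpoint + OnStart
val reply = ref.ask[String]("world")                      // enqueued on the Greeter inbox, drained by a MessageLoop
}}}

setupEndpoint enqueues the new EndpointData on `receivers`, so one of the MessageLoop threads delivers OnStart before any content message; PoisonPill is only used internally to shut those threads down.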
+ */ + +package org.apache.spark.rpc.netty + +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} + + +private[netty] sealed trait InboxMessage + +private[netty] case class ContentMessage( + senderAddress: RpcAddress, + content: Any, + needReply: Boolean, + context: NettyRpcCallContext) extends InboxMessage + +private[netty] case object OnStart extends InboxMessage + +private[netty] case object OnStop extends InboxMessage + +/** A message to tell all endpoints that a remote process has connected. */ +private[netty] case class RemoteProcessConnected(remoteAddress: RpcAddress) extends InboxMessage + +/** A message to tell all endpoints that a remote process has disconnected. */ +private[netty] case class RemoteProcessDisconnected(remoteAddress: RpcAddress) extends InboxMessage + +/** A message to tell all endpoints that a network error has happened. */ +private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteAddress: RpcAddress) + extends InboxMessage + +/** + * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. + */ +private[netty] class Inbox( + val endpointRef: NettyRpcEndpointRef, + val endpoint: RpcEndpoint) + extends Logging { + + inbox => // Give this an alias so we can use it more clearly in closures. + + @GuardedBy("this") + protected val messages = new java.util.LinkedList[InboxMessage]() + + /** True if the inbox (and its associated endpoint) is stopped. */ + @GuardedBy("this") + private var stopped = false + + /** Allow multiple threads to process messages at the same time. */ + @GuardedBy("this") + private var enableConcurrent = false + + /** The number of threads processing messages for this inbox. */ + @GuardedBy("this") + private var numActiveThreads = 0 + + // OnStart should be the first message to process + inbox.synchronized { + messages.add(OnStart) + } + + /** + * Process stored messages. + */ + def process(dispatcher: Dispatcher): Unit = { + var message: InboxMessage = null + inbox.synchronized { + if (!enableConcurrent && numActiveThreads != 0) { + return + } + message = messages.poll() + if (message != null) { + numActiveThreads += 1 + } else { + return + } + } + while (true) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + // The partial function to call + val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unsupported message $message from ${_sender}") + }) + if (!needReply) { + context.finish() + } + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + context.sendFailure(e) + } else { + context.finish() + } + // Throw the exception -- this exception will be caught by the safelyCall function. + // The endpoint's onError function will be called. 
+ throw e + } + + case OnStart => + endpoint.onStart() + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + inbox.synchronized { + if (!stopped) { + enableConcurrent = true + } + } + } + + case OnStop => + val activeThreads = inbox.synchronized { inbox.numActiveThreads } + assert(activeThreads == 1, + s"There should be only a single active thread but found $activeThreads threads.") + dispatcher.removeRpcEndpointRef(endpoint) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + + case RemoteProcessConnected(remoteAddress) => + endpoint.onConnected(remoteAddress) + + case RemoteProcessDisconnected(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + + case RemoteProcessConnectionError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } + + inbox.synchronized { + // "enableConcurrent" will be set to false after `onStop` is called, so we should check it + // every time. + if (!enableConcurrent && numActiveThreads != 1) { + // If we are not the only one worker, exit + numActiveThreads -= 1 + return + } + message = messages.poll() + if (message == null) { + numActiveThreads -= 1 + return + } + } + } + } + + def post(message: InboxMessage): Unit = inbox.synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages + onDrop(message) + } else { + messages.add(message) + false + } + } + + def stop(): Unit = inbox.synchronized { + // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last + // message + if (!stopped) { + // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only + // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources + // safely. + enableConcurrent = false + stopped = true + messages.add(OnStop) + // Note: The concurrent events in messages will be processed one by one. + } + } + + def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } + + /** Called when we are dropping a message. Test cases override this to test message dropping. */ + @VisibleForTesting + protected def onDrop(message: InboxMessage): Unit = { + logWarning(s"Drop $message because $endpointRef is stopped") + } + + /** + * Calls action closure, and calls the endpoint's onError function in the case of exceptions. + */ + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try action catch { + case NonFatal(e) => + try endpoint.onError(e) catch { + case NonFatal(ee) => logError(s"Ignoring error", ee) + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala new file mode 100644 index 0000000000000..21d5bb4923d1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
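From the endpoint's side, the Inbox above maps each InboxMessage onto one callback: OnStart before anything else, ContentMessage onto receive or receiveAndReply depending on needReply, and OnStop as the guaranteed last message. A small hypothetical endpoint, just to make that mapping explicit:

{{{
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def onStart(): Unit = { /* delivered as the OnStart inbox message */ }

  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => ()                          // ContentMessage with needReply = false
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String => context.reply(msg.reverse)  // ContentMessage with needReply = true
  }

  override def onStop(): Unit = { /* OnStop: always the last message the inbox processes */ }
}
}}}

Because EchoEndpoint does not extend ThreadSafeRpcEndpoint, the inbox sets enableConcurrent after OnStart, so receive/receiveAndReply may run on several dispatcher threads at once.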
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import scala.concurrent.Promise + +import org.apache.spark.Logging +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc.{RpcAddress, RpcCallContext} + +private[netty] abstract class NettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + override val senderAddress: RpcAddress, + needReply: Boolean) + extends RpcCallContext with Logging { + + protected def send(message: Any): Unit + + override def reply(response: Any): Unit = { + if (needReply) { + send(AskResponse(endpointRef, response)) + } else { + throw new IllegalStateException( + s"Cannot send $response to the sender because the sender does not expect a reply") + } + } + + override def sendFailure(e: Throwable): Unit = { + if (needReply) { + send(AskResponse(endpointRef, RpcFailure(e))) + } else { + logError(e.getMessage, e) + throw new IllegalStateException( + "Cannot send reply to the sender because the sender won't handle it") + } + } + + def finish(): Unit = { + if (!needReply) { + send(Ack(endpointRef)) + } + } +} + +/** + * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. + */ +private[netty] class LocalNettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + senderAddress: RpcAddress, + needReply: Boolean, + p: Promise[Any]) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + + override protected def send(message: Any): Unit = { + p.success(message) + } +} + +/** + * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back. + */ +private[netty] class RemoteNettyRpcCallContext( + nettyEnv: NettyRpcEnv, + endpointRef: NettyRpcEndpointRef, + callback: RpcResponseCallback, + senderAddress: RpcAddress, + needReply: Boolean) + extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + + override protected def send(message: Any): Unit = { + val reply = nettyEnv.serialize(message) + callback.onSuccess(reply) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala new file mode 100644 index 0000000000000..284284eb805b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -0,0 +1,511 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.rpc.netty + +import java.io._ +import java.net.{InetSocketAddress, URI} +import java.nio.ByteBuffer +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.concurrent.{Future, Promise} +import scala.reflect.ClassTag +import scala.util.{DynamicVariable, Failure, Success} +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.client._ +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.server._ +import org.apache.spark.rpc._ +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} +import org.apache.spark.util.{ThreadUtils, Utils} + +private[netty] class NettyRpcEnv( + val conf: SparkConf, + javaSerializerInstance: JavaSerializerInstance, + host: String, + securityManager: SecurityManager) extends RpcEnv(conf) with Logging { + + // Override numConnectionsPerPeer to 1 for RPC. + private val transportConf = SparkTransportConf.fromSparkConf( + conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), + conf.getInt("spark.rpc.io.threads", 0)) + + private val dispatcher: Dispatcher = new Dispatcher(this) + + private val transportContext = + new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) + + private val clientFactory = { + val bootstraps: java.util.List[TransportClientBootstrap] = + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } + transportContext.createClientFactory(bootstraps) + } + + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + + // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool + // to implement non-blocking send/ask. + // TODO: a non-blocking TransportClientFactory.createClient in future + private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( + "netty-rpc-connection", + conf.getInt("spark.rpc.connect.threads", 64)) + + @volatile private var server: TransportServer = _ + + private val stopped = new AtomicBoolean(false) + + /** + * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]], + * we just put messages to its [[Outbox]] to implement a non-blocking `send` method. + */ + private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]() + + /** + * Remove the address's Outbox and stop it. 
+ */ + private[netty] def removeOutbox(address: RpcAddress): Unit = { + val outbox = outboxes.remove(address) + if (outbox != null) { + outbox.stop() + } + } + + def start(port: Int): Unit = { + val bootstraps: java.util.List[TransportServerBootstrap] = + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + } else { + java.util.Collections.emptyList() + } + server = transportContext.createServer(port, bootstraps) + dispatcher.registerRpcEndpoint( + RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) + } + + override lazy val address: RpcAddress = { + require(server != null, "NettyRpcEnv has not yet started") + RpcAddress(host, server.getPort) + } + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.registerRpcEndpoint(name, endpoint) + } + + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { + val addr = RpcEndpointAddress(uri) + val endpointRef = new NettyRpcEndpointRef(conf, addr, this) + val verifier = new NettyRpcEndpointRef( + conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => + if (find) { + Future.successful(endpointRef) + } else { + Future.failed(new RpcEndpointNotFoundException(uri)) + } + }(ThreadUtils.sameThread) + } + + override def stop(endpointRef: RpcEndpointRef): Unit = { + require(endpointRef.isInstanceOf[NettyRpcEndpointRef]) + dispatcher.stop(endpointRef) + } + + private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = { + val targetOutbox = { + val outbox = outboxes.get(address) + if (outbox == null) { + val newOutbox = new Outbox(this, address) + val oldOutbox = outboxes.putIfAbsent(address, newOutbox) + if (oldOutbox == null) { + newOutbox + } else { + oldOutbox + } + } else { + outbox + } + } + if (stopped.get) { + // It's possible that we put `targetOutbox` after stopping. So we need to clean it. + outboxes.remove(address) + targetOutbox.stop() + } else { + targetOutbox.send(message) + } + } + + private[netty] def send(message: RequestMessage): Unit = { + val remoteAddr = message.receiver.address + if (remoteAddr == address) { + // Message to a local RPC endpoint. + val promise = Promise[Any]() + dispatcher.postLocalMessage(message, promise) + promise.future.onComplete { + case Success(response) => + val ack = response.asInstanceOf[Ack] + logTrace(s"Received ack from ${ack.sender}") + case Failure(e) => + logWarning(s"Exception when sending $message", e) + }(ThreadUtils.sameThread) + } else { + // Message to a remote RPC endpoint. 
+ postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + logWarning(s"Exception when sending $message", e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + val ack = deserialize[Ack](response) + logDebug(s"Receive ack from ${ack.sender}") + } + })) + } + } + + private[netty] def createClient(address: RpcAddress): TransportClient = { + clientFactory.createClient(address.host, address.port) + } + + private[netty] def ask(message: RequestMessage): Future[Any] = { + val promise = Promise[Any]() + val remoteAddr = message.receiver.address + if (remoteAddr == address) { + val p = Promise[Any]() + dispatcher.postLocalMessage(message, p) + p.future.onComplete { + case Success(response) => + val reply = response.asInstanceOf[AskResponse] + if (reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + case Failure(e) => + if (!promise.tryFailure(e)) { + logWarning("Ignore Exception", e) + } + }(ThreadUtils.sameThread) + } else { + postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning("Ignore Exception", e) + } + } + + override def onSuccess(response: Array[Byte]): Unit = { + val reply = deserialize[AskResponse](response) + if (reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + } + })) + } + promise.future + } + + private[netty] def serialize(content: Any): Array[Byte] = { + val buffer = javaSerializerInstance.serialize(content) + java.util.Arrays.copyOfRange( + buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) + } + + private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { + deserialize { () => + javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + } + } + + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + dispatcher.getRpcEndpointRef(endpoint) + } + + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = + new RpcEndpointAddress(address.host, address.port, endpointName).toString + + override def shutdown(): Unit = { + cleanup() + } + + override def awaitTermination(): Unit = { + dispatcher.awaitTermination() + } + + private def cleanup(): Unit = { + if (!stopped.compareAndSet(false, true)) { + return + } + + val iter = outboxes.values().iterator() + while (iter.hasNext()) { + val outbox = iter.next() + outboxes.remove(outbox.address) + outbox.stop() + } + if (timeoutScheduler != null) { + timeoutScheduler.shutdownNow() + } + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } + if (dispatcher != null) { + dispatcher.stop() + } + if (clientConnectionExecutor != null) { + clientConnectionExecutor.shutdownNow() + } + } + + override def deserialize[T](deserializationAction: () => T): T = { + NettyRpcEnv.currentEnv.withValue(this) { + deserializationAction() + } + } +} + +private[netty] object NettyRpcEnv extends Logging { + + /** + * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to 
[[NettyRpcEnv]]. + * Use `currentEnv` to wrap the deserialization codes. E.g., + * + * {{{ + * NettyRpcEnv.currentEnv.withValue(this) { + * your deserialization codes + * } + * }}} + */ + private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) +} + +private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { + + def create(config: RpcEnvConfig): RpcEnv = { + val sparkConf = config.conf + // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support + // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance + val javaSerializerInstance = + new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val nettyEnv = + new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) + val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => + nettyEnv.start(actualPort) + (nettyEnv, actualPort) + } + try { + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + } catch { + case NonFatal(e) => + nettyEnv.shutdown() + throw e + } + } +} + +private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf) + extends RpcEndpointRef(conf) with Serializable with Logging { + + @transient @volatile private var nettyEnv: NettyRpcEnv = _ + + @transient @volatile private var _address: RpcEndpointAddress = _ + + def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { + this(conf) + this._address = _address + this.nettyEnv = nettyEnv + } + + override def address: RpcAddress = _address.toRpcAddress + + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject() + _address = in.readObject().asInstanceOf[RpcEndpointAddress] + nettyEnv = NettyRpcEnv.currentEnv.value + } + + private def writeObject(out: ObjectOutputStream): Unit = { + out.defaultWriteObject() + out.writeObject(_address) + } + + override def name: String = _address.name + + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + val promise = Promise[Any]() + val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { + override def run(): Unit = { + promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration)) + } + }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) + val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) + f.onComplete { v => + timeoutCancelable.cancel(true) + if (!promise.tryComplete(v)) { + logWarning(s"Ignore message $v") + } + }(ThreadUtils.sameThread) + promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) + } + + override def send(message: Any): Unit = { + require(message != null, "Message is null") + nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false)) + } + + override def toString: String = s"NettyRpcEndpointRef(${_address})" + + def toURI: URI = new URI(s"spark://${_address}") + + final override def equals(that: Any): Boolean = that match { + case other: NettyRpcEndpointRef => _address == other._address + case _ => false + } + + final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() +} + +/** + * The message that is sent from the sender to the receiver. + */ +private[netty] case class RequestMessage( + senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean) + +/** + * The base trait for all messages that are sent back from the receiver to the sender. 
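The readObject/writeObject pair above means only the RpcEndpointAddress travels in the byte stream; the env reference is re-attached from NettyRpcEnv.currentEnv. A conceptual sketch, assuming `nettyEnv` and `ref` are already in scope and the code lives in the org.apache.spark.rpc.netty package (serialize and deserialize are private[netty]):

{{{
val bytes: Array[Byte] = nettyEnv.serialize(ref)                 // writeObject: only _address is written
val restored = nettyEnv.deserialize[NettyRpcEndpointRef](bytes)  // readObject re-attaches nettyEnv from currentEnv
assert(restored == ref && restored.address == ref.address)       // equality is defined by _address
}}}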
+ */ +private[netty] trait ResponseMessage + +/** + * The reply for `ask` from the receiver side. + */ +private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) + extends ResponseMessage + +/** + * A message to send back to the receiver side. It's necessary because [[TransportClient]] only + * clean the resources when it receives a reply. + */ +private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage + +/** + * A response that indicates some failure happens in the receiver side. + */ +private[netty] case class RpcFailure(e: Throwable) + +/** + * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast + * network events and forward messages to [[Dispatcher]]. + */ +private[netty] class NettyRpcHandler( + dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + + private type ClientAddress = RpcAddress + private type RemoteEnvAddress = RpcAddress + + // Store all client addresses and their NettyRpcEnv addresses. + // TODO: Is this even necessary? + @GuardedBy("this") + private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() + + override def receive( + client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { + val requestMessage = nettyEnv.deserialize[RequestMessage](message) + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val remoteEnvAddress = requestMessage.senderAddress + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. + var dispatchRemoteProcessConnected = false + synchronized { + if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { + // clientAddr connects at the first time, fire "RemoteProcessConnected" + dispatchRemoteProcessConnected = true + } + } + if (dispatchRemoteProcessConnected) { + dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) + } + dispatcher.postRemoteMessage(requestMessage, callback) + } + + override def getStreamManager: StreamManager = new OneForOneStreamManager + + override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _)) + } + if (broadcastMessage.isEmpty) { + logError(cause.getMessage, cause) + } else { + dispatcher.postToAll(broadcastMessage.get) + } + } else { + // If the channel is closed before connecting, its remoteAddress will be null. 
+ // See java.net.Socket.getRemoteSocketAddress + // Because we cannot get a RpcAddress, just log it + logError("Exception before connecting to the client", cause) + } + } + + override def connectionTerminated(client: TransportClient): Unit = { + val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + nettyEnv.removeOutbox(clientAddr) + val messageOpt: Option[RemoteProcessDisconnected] = + synchronized { + remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => + remoteAddresses -= clientAddr + Some(RemoteProcessDisconnected(remoteEnvAddress)) + } + } + messageOpt.foreach(dispatcher.postToAll) + } else { + // If the channel is closed before connecting, its remoteAddress will be null. In this case, + // we can ignore it since we don't fire "Associated". + // See java.net.Socket.getRemoteSocketAddress + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala new file mode 100644 index 0000000000000..7d9d593b36241 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.Callable +import javax.annotation.concurrent.GuardedBy + +import scala.util.control.NonFatal + +import org.apache.spark.SparkException +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.rpc.RpcAddress + +private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback) + +private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { + + outbox => // Give this an alias so we can use it more clearly in closures. + + @GuardedBy("this") + private val messages = new java.util.LinkedList[OutboxMessage] + + @GuardedBy("this") + private var client: TransportClient = null + + /** + * connectFuture points to the connect task. If there is no connect task, connectFuture will be + * null. + */ + @GuardedBy("this") + private var connectFuture: java.util.concurrent.Future[Unit] = null + + @GuardedBy("this") + private var stopped = false + + /** + * If there is any thread draining the message queue + */ + @GuardedBy("this") + private var draining = false + + /** + * Send a message. If there is no active connection, cache it and launch a new connection. If + * [[Outbox]] is stopped, the sender will be notified with a [[SparkException]]. 
+ */ + def send(message: OutboxMessage): Unit = { + val dropped = synchronized { + if (stopped) { + true + } else { + messages.add(message) + false + } + } + if (dropped) { + message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + } else { + drainOutbox() + } + } + + /** + * Drain the message queue. If another thread is already draining it, just exit. If the connection has + * not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to set up the + * connection. + */ + private def drainOutbox(): Unit = { + var message: OutboxMessage = null + synchronized { + if (stopped) { + return + } + if (connectFuture != null) { + // We are connecting to the remote address, so just exit + return + } + if (client == null) { + // There is no connect task but client is null, so we need to launch the connect task. + launchConnectTask() + return + } + if (draining) { + // Another thread is already draining, so just exit + return + } + message = messages.poll() + if (message == null) { + return + } + draining = true + } + while (true) { + try { + val _client = synchronized { client } + if (_client != null) { + _client.sendRpc(message.content, message.callback) + } else { + assert(stopped == true) + } + } catch { + case NonFatal(e) => + handleNetworkFailure(e) + return + } + synchronized { + if (stopped) { + return + } + message = messages.poll() + if (message == null) { + draining = false + return + } + } + } + } + + private def launchConnectTask(): Unit = { + connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] { + + override def call(): Unit = { + try { + val _client = nettyEnv.createClient(address) + outbox.synchronized { + client = _client + if (stopped) { + closeClient() + } + } + } catch { + case ie: InterruptedException => + // exit + return + case NonFatal(e) => + outbox.synchronized { connectFuture = null } + handleNetworkFailure(e) + return + } + outbox.synchronized { connectFuture = null } + // It's possible that no thread is draining now. If we don't drain here, we cannot send the + // messages until the next message arrives. + drainOutbox() + } + }) + } + + /** + * Stop the [[Outbox]] and notify all pending messages of the failure cause. + */ + private def handleNetworkFailure(e: Throwable): Unit = { + synchronized { + assert(connectFuture == null) + if (stopped) { + return + } + stopped = true + closeClient() + } + // Remove this Outbox from nettyEnv so that further messages will create a new Outbox along + // with a new connection + nettyEnv.removeOutbox(address) + + // Notify the remaining messages of the connection failure + // + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.callback.onFailure(e) + message = messages.poll() + } + assert(messages.isEmpty) + } + + private def closeClient(): Unit = synchronized { + // Not sure if `client.close` is idempotent. Just for safety. + if (client != null) { + client.close() + } + client = null + } + + /** + * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a + * [[SparkException]].
+ */ + def stop(): Unit = { + synchronized { + if (stopped) { + return + } + stopped = true + if (connectFuture != null) { + connectFuture.cancel(true) + } + closeClient() + } + + // We always check `stopped` before updating messages, so here we can make sure no thread will + // update messages and it's safe to just drain the queue. + var message = messages.poll() + while (message != null) { + message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message = messages.poll() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala new file mode 100644 index 0000000000000..87b6236936817 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.SparkException +import org.apache.spark.rpc.RpcAddress + +/** + * An address identifier for an RPC endpoint. + * + * @param host host name of the remote process. + * @param port the port the remote RPC environment binds to. + * @param name name of the remote endpoint. + */ +private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { + + def toRpcAddress: RpcAddress = RpcAddress(host, port) + + override val toString = s"spark://$name@$host:$port" +} + +private[netty] object RpcEndpointAddress { + + def apply(sparkUrl: String): RpcEndpointAddress = { + try { + val uri = new java.net.URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid Spark URL: " + sparkUrl) + } + RpcEndpointAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala new file mode 100644 index 0000000000000..99f20da2d66aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
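For reference, the URL form accepted by RpcEndpointAddress.apply above, with arbitrary example values (the class is private[netty], so this only compiles inside that package):

{{{
val addr = RpcEndpointAddress("spark://BlockManagerMaster@10.0.0.5:43210")
assert(addr.host == "10.0.0.5" && addr.port == 43210 && addr.name == "BlockManagerMaster")
assert(addr.toString == "spark://BlockManagerMaster@10.0.0.5:43210")

// A URL with no user-info component carries no endpoint name and is rejected:
// RpcEndpointAddress("spark://10.0.0.5:43210")   // throws SparkException("Invalid Spark URL: ...")
}}}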
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} + +/** + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * + * This is used when setting up a remote endpoint reference. + */ +private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher) + extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name)) + } +} + +private[netty] object RpcEndpointVerifier { + val NAME = "endpoint-verifier" + + /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + case class CheckExistence(name: String) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index 11d123eec43ca..146cfb9ba8037 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -34,12 +34,27 @@ class AccumulableInfo private[spark] ( override def equals(other: Any): Boolean = other match { case acc: AccumulableInfo => this.id == acc.id && this.name == acc.name && - this.update == acc.update && this.value == acc.value + this.update == acc.update && this.value == acc.value && + this.internal == acc.internal case _ => false } + + override def hashCode(): Int = { + val state = Seq(id, name, update, value, internal) + state.map(_.hashCode).reduceLeft(31 * _ + _) + } } object AccumulableInfo { + def apply( + id: Long, + name: String, + update: Option[String], + value: String, + internal: Boolean): AccumulableInfo = { + new AccumulableInfo(id, name, update, value, internal) + } + def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = { new AccumulableInfo(id, name, update, value, internal = false) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d2..a3d2db31301b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. 
+ * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). + */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index daf9b0f95273e..995862ece5944 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -45,17 +45,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. * - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). 
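A concrete instance of the stage-splitting rule described above, with an arbitrary input path and assuming `sc` is a SparkContext:

{{{
// flatMap and map have narrow dependencies and are pipelined into one ShuffleMapStage;
// reduceByKey adds a shuffle dependency, so count() runs as a second stage (a ResultStage)
// that reads the map output files written by the first.
val counts = sc.textFile("hdfs://namenode/words.txt")
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)
counts.count()   // one job, two stages: ShuffleMapStage -> ResultStage
}}}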
+ * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. + * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. 
*/ @@ -136,22 +184,6 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) - // Flag to control if reduce tasks are assigned preferred locations - private val shuffleLocalityEnabled = - sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) - // Number of map, reduce tasks above which we do not assign preferred locations - // based on map output sizes. We limit the size of jobs for which assign preferred locations - // as computing the top locations by size becomes expensive. - private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 - // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that - private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 - - // Fraction of total map output that must be at a location for it to considered as a preferred - // location for a reduce task. - // Making this larger will focus on fewer locations where most data can be read locally, but - // may lead to more delay in scheduling if those locations are busy. - private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 - /** * Called by the TaskSetManager to report task's starting. */ @@ -250,11 +282,12 @@ class DAGScheduler( case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, firstJobId) + getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => + shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + } // Then register current shuffleDep val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage - stage } } @@ -294,12 +327,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -320,10 +353,12 @@ class DAGScheduler( if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - for (i <- 0 until locs.length) { - stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing + (0 until locs.length).foreach { i => + if (locs(i) ne null) { + // locs(i) will be null if missing + stage.addOutputLoc(i, locs(i)) + } } - stage.numAvailableOutputs = locs.count(_ != null) } else { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown @@ -365,16 +400,6 @@ class DAGScheduler( parents.toList } - /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { - val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) - while (parentsWithNoMapStage.nonEmpty) { - val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) - shuffleToMapStage(currentShufDep.shuffleId) = stage - } - } - /** Find ancestor shuffle 
dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { val parents = new Stack[ShuffleDependency[_, _, _]] @@ -391,11 +416,9 @@ class DAGScheduler( if (!shuffleToMapStage.contains(shufDep.shuffleId)) { parents.push(shufDep) } - - waitingForVisit.push(shufDep.rdd) case _ => - waitingForVisit.push(dep.rdd) } + waitingForVisit.push(dep.rdd) } } } @@ -511,12 +534,25 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => + r.resultOfJob = None + case m: ShuffleMapStage => + m.mapStageJobs = m.mapStageJobs.filter(_ != job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the job finishes executing or can be used to cancel the job. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name */ def submitJob[T, U]( rdd: RDD[T], @@ -535,6 +571,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -547,6 +584,18 @@ class DAGScheduler( waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. Throws an exception if the job fails, or returns normally if successful. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -570,6 +619,17 @@ class DAGScheduler( } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -586,6 +646,41 @@ class DAGScheduler( listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back.
The waiter + * can be used to block until the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. + // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queried + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -594,6 +689,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs associated with the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -731,31 +829,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted.
- finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val jobSubmissionTime = clock.getTimeMillis() - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
+ finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.size)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.mapStageJobs = job :: finalStage.mapStageJobs + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (finalStage.isAvailable) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) + } + submitWaitingStages() } @@ -786,27 +930,15 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. - val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { - stage match { - case stage: ShuffleMapStage => - val allPartitions = 0 until stage.numPartitions - val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty } - (allPartitions, filteredPartitions) - case stage: ResultStage => - val job = stage.resultOfJob.get - val allPartitions = 0 until job.numPartitions - val filteredPartitions = allPartitions.filter { id => !job.finished(id) } - (allPartitions, filteredPartitions) - } - } + val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() // Create internal accumulators if the stage has no accumulators initialized. 
// Reset internal accumulators only if this stage is not partially submitted // Otherwise, we may override existing accumulator values from some tasks - if (stage.internalAccumulators.isEmpty || allPartitions == partitionsToCompute) { + if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) { stage.resetInternalAccumulators() } @@ -825,7 +957,7 @@ class DAGScheduler( case s: ResultStage => val job = s.resultOfJob.get partitionsToCompute.map { id => - val p = job.partitions(id) + val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) }.toMap } @@ -855,7 +987,7 @@ class DAGScheduler( case stage: ShuffleMapStage => closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) @@ -886,7 +1018,7 @@ class DAGScheduler( case stage: ResultStage => val job = stage.resultOfJob.get partitionsToCompute.map { id => - val p: Int = job.partitions(id) + val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, @@ -902,8 +1034,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -970,8 +1102,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. @@ -991,7 +1126,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1037,7 +1172,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1052,45 +1187,36 @@ class DAGScheduler( // we registered these map outputs. 
mapOutputTracker.registerMapOutputs( shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocs.map(list => if (list.isEmpty) null else list.head), + shuffleStage.outputLocs.map(_.headOption.orNull), changeEpoch = true) clearCacheLocs() - if (shuffleStage.outputLocs.contains(Nil)) { + + if (!shuffleStage.isAvailable) { // Some tasks had failed; let's resubmit this shuffleStage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + .map(_._2).mkString(", ")) submitStage(shuffleStage) } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (shuffleStage <- waitingStages) { - logInfo("Missing parents for " + shuffleStage + ": " + - getMissingParentStages(shuffleStage)) - } - for (shuffleStage <- waitingStages if getMissingParentStages(shuffleStage).isEmpty) - { - newlyRunnable += shuffleStage - } - waitingStages --= newlyRunnable - runningStages ++= newlyRunnable - for { - shuffleStage <- newlyRunnable.sortBy(_.id) - jobId <- activeJobForStage(shuffleStage) - } { - logInfo("Submitting " + shuffleStage + " (" + - shuffleStage.rdd + "), which is now runnable") - submitMissingTasks(shuffleStage, jobId) + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } } } + + // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) @@ -1101,7 +1227,6 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ID ${failedStage.latestInfo.attemptId}) running") } else { - // It is likely that we receive multiple FetchFailed for a single stage (because we have // multiple tasks running concurrently on different executors). In that case, it is // possible the fetch failure has already been handled by the scheduler. @@ -1117,6 +1242,11 @@ class DAGScheduler( if (disallowStageRetryForTest) { abortStage(failedStage, "Fetch failure will not retry stage due to testing config", None) + } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { + abortStage(failedStage, s"$failedStage (${failedStage.name}) " + + s"has failed the maximum allowable number of " + + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + + s"Most recent failure reason: ${failureMessage}", None) } else if (failedStages.isEmpty) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. 
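The fetch-failure branch above aborts a stage only after several distinct stage attempts have failed with a FetchFailed; repeated failures from tasks of the same attempt are counted once, and the counter is cleared when the stage later succeeds. A minimal standalone sketch of that bookkeeping, using illustrative names that merely mirror the Stage fields added later in this patch:

import scala.collection.mutable

object FetchFailureSketch {
  // Mirrors Stage.MAX_CONSECUTIVE_FETCH_FAILURES introduced by this patch.
  val MaxConsecutiveFetchFailures = 4

  final class StageFailureTracker {
    // One entry per stage attempt that hit a FetchFailed, so duplicate failures
    // from tasks of the same attempt are not double-counted.
    private val fetchFailedAttemptIds = mutable.HashSet[Int]()

    // Record a failed attempt and report whether the stage should now be aborted.
    def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = {
      fetchFailedAttemptIds += stageAttemptId
      fetchFailedAttemptIds.size >= MaxConsecutiveFetchFailures
    }

    // Called when an attempt succeeds, so unrelated failures much later in a
    // long-running job do not slowly accumulate toward the abort threshold.
    def clearFailures(): Unit = fetchFailedAttemptIds.clear()
  }

  def main(args: Array[String]): Unit = {
    val tracker = new StageFailureTracker
    assert(!tracker.failedOnFetchAndShouldAbort(stageAttemptId = 0))
    assert(!tracker.failedOnFetchAndShouldAbort(stageAttemptId = 0)) // same attempt counts once
    assert(!tracker.failedOnFetchAndShouldAbort(stageAttemptId = 1))
    assert(!tracker.failedOnFetchAndShouldAbort(stageAttemptId = 2))
    assert(tracker.failedOnFetchAndShouldAbort(stageAttemptId = 3))  // fourth distinct failed attempt
  }
}

The map-stage jobs marked as finished above hand their listener a MapOutputStatistics object, which is the point of submitMapStage: adaptive query planning can run only the map side, inspect how much shuffle data each partition produced, and only then decide how to lay out the downstream stage. The sketch below models that decision in isolation; it assumes the statistics expose per-partition byte counts (e.g. a bytesByPartitionId array), and the 64 MB target is an arbitrary value chosen for illustration, not a default from this patch.

object ReducerCountSketch {
  // Choose a reduce-side parallelism so that each reducer receives roughly
  // targetBytesPerReducer of shuffle input, based on map output statistics.
  def pickNumReducers(bytesByPartitionId: Array[Long], targetBytesPerReducer: Long): Int = {
    val totalBytes = bytesByPartitionId.sum
    math.max(1, math.ceil(totalBytes.toDouble / targetBytesPerReducer).toInt)
  }

  def main(args: Array[String]): Unit = {
    // Pretend the map-stage job's callback handed us these per-partition sizes.
    val bytesByPartitionId = Array(10L << 20, 200L << 20, 5L << 20)
    println(pickNumReducers(bytesByPartitionId, targetBytesPerReducer = 64L << 20)) // prints 4
  }
}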
@@ -1182,7 +1312,7 @@ class DAGScheduler( // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head) + val locs = stage.outputLocs.map(_.headOption.orNull) mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) } if (shuffleToMapStage.isEmpty) { @@ -1240,10 +1370,17 @@ class DAGScheduler( if (errorMessage.isEmpty) { logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) stage.latestInfo.completionTime = Some(clock.getTimeMillis()) + + // Clear failure count for this stage, now that it's succeeded. + // We only limit consecutive failures of stage attempts,so that if a stage is + // re-used many times in a long-running job, unrelated failures don't eventually cause the + // stage to be aborted. + stage.clearFailures() } else { stage.latestInfo.stageFailed(errorMessage.get) logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime)) } + outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) runningStages -= stage @@ -1407,28 +1544,24 @@ class DAGScheduler( return locs } } + case _ => } - // If the RDD has shuffle dependencies and shuffle locality is enabled, pick locations that - // have at least REDUCER_PREF_LOCS_FRACTION of data as preferred locations - if (shuffleLocalityEnabled && rdd.partitions.length < SHUFFLE_PREF_REDUCE_THRESHOLD) { - rdd.dependencies.foreach { - case s: ShuffleDependency[_, _, _] => - if (s.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD) { - // Get the preferred map output locations for this reducer - val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, - partition, rdd.partitions.length, REDUCER_PREF_LOCS_FRACTION) - if (topLocsForReducer.nonEmpty) { - return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) - } - } - case _ => - } - } Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() @@ -1462,6 +1595,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index f72a52e85dc15..dda3b6cc7f960 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], @@ -45,6 +46,15 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 5a06ef02f5c57..000a021a528cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -109,7 +109,9 @@ private[spark] class EventLoggingListener( if (shouldOverwrite && fileSystem.exists(path)) { logWarning(s"Event log $path already exists. Overwriting...") - fileSystem.delete(path, true) + if (!fileSystem.delete(path, true)) { + logWarning(s"Error deleting $path") + } } /* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844). @@ -216,7 +218,9 @@ private[spark] class EventLoggingListener( if (fileSystem.exists(target)) { if (shouldOverwrite) { logWarning(s"Event log $target already exists. 
Overwriting...") - fileSystem.delete(target, true) + if (!fileSystem.delete(target, true)) { + logWarning(s"Error deleting $target") + } } else { throw new IOException("Target log file already exists (%s)".format(logPath)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 2bc43a9186449..0a98c69b89ea5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -23,16 +23,20 @@ import org.apache.spark.executor.ExecutorExitCode * Represents an explanation for a executor or whole slave failing or exiting. */ private[spark] -class ExecutorLossReason(val message: String) { +class ExecutorLossReason(val message: String) extends Serializable { override def toString: String = message } private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { +case class ExecutorExited(exitCode: Int, isNormalExit: Boolean, reason: String) + extends ExecutorLossReason(reason) + +private[spark] object ExecutorExited { + def apply(exitCode: Int, isNormalExit: Boolean): ExecutorExited = { + ExecutorExited(exitCode, isNormalExit, ExecutorExitCode.explainExitCode(exitCode)) + } } private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} + extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5d926377ce86b..add0dedc03f44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
@@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - 
logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 5821afea98982..551e39a81b695 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -83,8 +83,8 @@ private[spark] class Pool( null } - override def executorLost(executorId: String, host: String) { - schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) + override def executorLost(executorId: String, host: String, reason: ExecutorLossReason) { + schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } override def checkSpeculatableTasks(): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca4810..c1d86af7e8fb5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,24 +17,36 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). */ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). 
+ */ var resultOfJob: Option[ActiveJob] = None + override def findMissingPartitions(): Seq[Int] = { + val job = resultOfJob.get + (0 until job.numPartitions).filter(id => !job.finished(id)) + } + override def toString: String = "ResultStage " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c4dc080e2b22b..fb693721a9cb6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -44,7 +44,7 @@ private[spark] class ResultTask[T, U]( stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, - @transient locs: Seq[TaskLocation], + locs: Seq[TaskLocation], val outputId: Int, internalAccumulators: Seq[Accumulator[Long]]) extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index a87ef030e69c2..ab00bc8f0bf4e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -42,7 +42,7 @@ private[spark] trait Schedulable { def addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit + def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit def checkSpeculatableTasks(): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 48d8d8e9c4b78..3832d99eddaef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. + * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. */ private[spark] class ShuffleMapStage( id: Int, @@ -37,12 +45,36 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id + /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ + var mapStageJobs: List[ActiveJob] = Nil + + /** + * Number of partitions that have shuffle outputs. + * When this reaches [[numPartitions]], this map stage is ready. + * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. 
+ */ var numAvailableOutputs: Int = 0 + /** + * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. + * This should be the same as `outputLocs.contains(Nil)`. + */ def isAvailable: Boolean = numAvailableOutputs == numPartitions + /** + * List of [[MapStatus]] for each partition. The index of the array is the map partition id, + * and each value in the array is the list of possible [[MapStatus]] for a partition + * (a single task might run multiple times). + */ val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + override def findMissingPartitions(): Seq[Int] = { + val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) + assert(missing.size == numPartitions - numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - numAvailableOutputs}") + missing + } + def addOutputLoc(partition: Int, status: MapStatus): Unit = { val prevList = outputLocs(partition) outputLocs(partition) = status :: prevList diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 1cf06856ffbc2..5ce4a484344f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,29 +24,35 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. 
* + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). + * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ -private[spark] abstract class Stage( +private[scheduler] abstract class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, @@ -55,12 +61,12 @@ private[spark] abstract class Stage( val callSite: CallSite) extends Logging { - val numPartitions = rdd.partitions.size + val numPartitions = rdd.partitions.length /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 @@ -92,6 +98,29 @@ private[spark] abstract class Stage( */ private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + /** + * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these + * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * We keep track of each attempt ID that has failed to avoid recording duplicate failures if + * multiple tasks from the same stage attempt fail (SPARK-5945). + */ + private val fetchFailedAttemptIds = new HashSet[Int] + + private[scheduler] def clearFailures() : Unit = { + fetchFailedAttemptIds.clear() + } + + /** + * Check whether we should abort the failedStage due to multiple consecutive fetch failures. + * + * This method updates the running set of failed stage attempts and returns + * true if the number of failures exceeds the allowable number of failures. + */ + private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { + fetchFailedAttemptIds.add(stageAttemptId) + fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES + } + /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -109,4 +138,12 @@ private[spark] abstract class Stage( case stage: Stage => stage != null && stage.id == id case _ => false } + + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ + def findMissingPartitions(): Seq[Int] +} + +private[scheduler] object Stage { + // The number of consecutive failures allowed before a stage is aborted + val MAX_CONSECUTIVE_FETCH_FAILURES = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced77700..f113c2b1b8433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index da07ce2c6ea49..1b65926f5c749 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -67,7 +67,7 @@ private[spark] object TaskLocation { if (hstr.equals(str)) { new HostTaskLocation(str) } else { - new HostTaskLocation(hstr) + new HDFSCacheTaskLocation(hstr) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1705e7f962de2..1c7bfe89c02ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -332,7 +332,8 @@ private[spark] class TaskSchedulerImpl( // We lost this entire executor, so remember that it's gone val execId = taskIdToExecutorId(tid) if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) + removeExecutor(execId, + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) failedExecutor = Some(execId) } } @@ -464,7 +465,7 @@ private[spark] class TaskSchedulerImpl( if (activeExecutorIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) + removeExecutor(executorId, reason) failedExecutor = Some(executorId) } else { // We may get multiple executorLost() calls with different loss reasons. 
For example, one @@ -482,7 +483,7 @@ private[spark] class TaskSchedulerImpl( } /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { + private def removeExecutor(executorId: String, reason: ExecutorLossReason) { activeExecutorIds -= executorId val host = executorIdToHost(executorId) val execs = executorsByHost.getOrElse(host, new HashSet) @@ -497,7 +498,7 @@ private[spark] class TaskSchedulerImpl( } } executorIdToHost -= executorId - rootPool.executorLost(executorId, host) + rootPool.executorLost(executorId, host, reason) } def executorAdded(execId: String, host: String) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 818b95d67f6be..987800d3d1f1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -177,14 +177,11 @@ private[spark] class TaskSetManager( var emittedTaskSizeWarning = false - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there + /** Add a task to all the pending-task lists that it should be on. */ + private def addPendingTask(index: Int) { + // Utility method that adds `index` to a list only if it's not already there def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { + if (!list.contains(index)) { list += index } } @@ -219,9 +216,7 @@ private[spark] class TaskSetManager( addTo(pendingTasksWithNoPrefs) } - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } + allPendingTasks += index // No point scanning this whole list to find the old task there } /** @@ -487,8 +482,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, @@ -709,6 +704,11 @@ private[spark] class TaskSetManager( } ef.exception + case e: ExecutorLostFailure if e.isNormalExit => + logInfo(s"Task $tid failed because while it was being computed, its executor" + + s" exited normally. Not marking the task as failed.") + None + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) None @@ -722,10 +722,9 @@ private[spark] class TaskSetManager( put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) addPendingTask(index) - if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) { - // If a task failed because its attempt to commit was denied, do not count this failure - // towards failing the stage. 
This is intended to prevent spurious stage failures in cases - // where many speculative tasks are launched and denied to commit. + if (!isZombie && state != TaskState.KILLED + && reason.isInstanceOf[TaskFailedReason] + && reason.asInstanceOf[TaskFailedReason].shouldEventuallyFailJob) { assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { @@ -778,19 +777,7 @@ private[spark] class TaskSetManager( } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // Re-enqueue pending tasks for this host based on the status of the cluster. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because dequeueTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding = true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding = true) - } - + override def executorLost(execId: String, host: String, reason: ExecutorLossReason) { // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor @@ -809,9 +796,12 @@ private[spark] class TaskSetManager( } } } - // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) + val isNormalExit: Boolean = reason match { + case exited: ExecutorExited => exited.isNormalExit + case _ => false + } + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, isNormalExit)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 06f5438433b6e..4652df32efa74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.ExecutorLossReason import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -35,9 +36,13 @@ private[spark] object CoarseGrainedClusterMessages { case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) extends CoarseGrainedClusterMessage + sealed trait RegisterExecutorResponse + case object RegisteredExecutor extends CoarseGrainedClusterMessage + with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage + with RegisterExecutorResponse // Executors to driver case class RegisterExecutor( @@ -70,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class 
RemoveExecutor(executorId: String, reason: ExecutorLossReason) + extends CoarseGrainedClusterMessage case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage @@ -92,6 +98,17 @@ private[spark] object CoarseGrainedClusterMessages { hostToLocalTaskCount: Map[String, Int]) extends CoarseGrainedClusterMessage + // Check if an executor was force-killed but for a normal reason. + // This could be the case if the executor is preempted, for instance. + case class GetExecutorLossReason(executorId: String) extends CoarseGrainedClusterMessage + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage + // Used internally by executors to shut themselves down. + case object Shutdown extends CoarseGrainedClusterMessage + + // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back + // to the AM. + case object DriverHello extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5730a87f960a0..55a564b5c8eac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** @@ -82,7 +83,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[RpcAddress, String] + protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") @@ -128,6 +129,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { @@ -185,8 +187,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } override def onDisconnected(remoteAddress: RpcAddress): Unit = { - addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, - "remote Rpc client disassociated")) + addressToExecutorId + .get(remoteAddress) + .foreach(removeExecutor(_, SlaveLost("remote Rpc client disassociated"))) } // Make fake resource offers on just one executor @@ -227,7 +230,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String): Unit = { + def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -239,9 +242,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } 
totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) - scheduler.executorLost(executorId, SlaveLost(reason)) + scheduler.executorLost(executorId, reason) listenerBus.post( - SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) case None => logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -263,8 +266,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // TODO (prashant) send conf instead of properties - driverEndpoint = rpcEnv.setupEndpoint( - CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) + driverEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties)) + } + + protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new DriverEndpoint(rpcEnv, properties) } def stopExecutors() { @@ -304,7 +310,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: ExecutorLossReason) { try { driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { @@ -432,6 +438,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (!replace) { doRequestTotalExecutors( numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } else { + numPendingExecutors += knownExecutors.size } doKillExecutors(executorsToKill) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0324c9dab910b..641638a77d5f5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -65,7 +65,9 @@ private[spark] class SimrSchedulerBackend( override def stop() { val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) - fs.delete(new Path(driverFilePath), false) + if (!fs.delete(new Path(driverFilePath), false)) { + logWarning(s"error deleting ${driverFilePath}") + } super.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bbe51b4a09a22..2625c3e7ac718 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,7 +23,8 @@ import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} -import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.scheduler._ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( @@ -36,6 +37,9 @@ private[spark] class SparkDeploySchedulerBackend( private var client: AppClient = null private var stopping = false + private val launcherBackend = 
new LauncherBackend() { + override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ @volatile private var appId: String = _ @@ -47,6 +51,7 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() + launcherBackend.connect() // The endpoint for executors to talk to us val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, @@ -87,24 +92,20 @@ private[spark] class SparkDeploySchedulerBackend( command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) waitForRegistration() + launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop() { - stopping = true - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) - } + override def stop(): Unit = synchronized { + stop(SparkAppHandle.State.FINISHED) } override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) this.appId = appId notifyContext() + launcherBackend.setAppId(appId) } override def disconnected() { @@ -117,6 +118,7 @@ private[spark] class SparkDeploySchedulerBackend( override def dead(reason: String) { notifyContext() if (!stopping) { + launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) @@ -135,11 +137,11 @@ private[spark] class SparkDeploySchedulerBackend( override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) + case Some(code) => ExecutorExited(code, isNormalExit = true, message) case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(fullId, message)) - removeExecutor(fullId.split("/")(1), reason.toString) + removeExecutor(fullId.split("/")(1), reason) } override def sufficientResourcesRegistered(): Boolean = { @@ -188,4 +190,19 @@ private[spark] class SparkDeploySchedulerBackend( registrationBarrier.release() } + private def stop(finalState: SparkAppHandle.State): Unit = synchronized { + stopping = true + + launcherBackend.setState(finalState) + launcherBackend.close() + + super.stop() + client.stop() + + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 044f6288fabdd..38218b9c08fd8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -17,12 +17,13 @@ package org.apache.spark.scheduler.cluster +import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Future, ExecutionContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler._ import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{ThreadUtils, RpcUtils} @@ -43,8 +44,10 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val 
yarnSchedulerEndpoint = rpcEnv.setupEndpoint( - YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) + private val yarnSchedulerEndpoint = new YarnSchedulerEndpoint(rpcEnv) + + private val yarnSchedulerEndpointRef = rpcEnv.setupEndpoint( + YarnSchedulerBackend.ENDPOINT_NAME, yarnSchedulerEndpoint) private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf) @@ -53,7 +56,7 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean]( + yarnSchedulerEndpointRef.askWithRetry[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } @@ -61,7 +64,7 @@ private[spark] abstract class YarnSchedulerBackend( * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpointRef.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -90,6 +93,41 @@ private[spark] abstract class YarnSchedulerBackend( } } + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new YarnDriverEndpoint(rpcEnv, properties) + } + + /** + * Override the DriverEndpoint to add extra logic for the case when an executor is disconnected. + * This endpoint communicates with the executors and queries the AM for an executor's exit + * status when the executor is disconnected. + */ + private class YarnDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + /** + * When onDisconnected is received at the driver endpoint, the superclass DriverEndpoint + * handles it by assuming the Executor was lost for a bad reason and removes the executor + * immediately. + * + * In YARN's case however it is crucial to talk to the application master and ask why the + * executor had exited. In particular, the executor may have exited due to the executor + * having been preempted. If the executor "exited normally" according to the application + * master then we pass that information down to the TaskSetManager to inform the + * TaskSetManager that tasks on that lost executor should not count towards a job failure. + * + * TODO there's a race condition where while we are querying the ApplicationMaster for + * the executor loss reason, there is the potential that tasks will be scheduled on + * the executor that failed. We should fix this by having this onDisconnected event + * also "blacklist" executors so that tasks are not assigned to them. + */ + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + yarnSchedulerEndpoint.handleExecutorDisconnectedFromDriver(executorId, rpcAddress) + } + } + } + /** * An [[RpcEndpoint]] that communicates with the ApplicationMaster. 
*/ @@ -101,10 +139,39 @@ private[spark] abstract class YarnSchedulerBackend( ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) + private[YarnSchedulerBackend] def handleExecutorDisconnectedFromDriver( + executorId: String, + executorRpcAddress: RpcAddress): Unit = { + amEndpoint match { + case Some(am) => + val lossReasonRequest = GetExecutorLossReason(executorId) + val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) + future onSuccess { + case reason: ExecutorLossReason => { + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) + } + } + future onFailure { + case NonFatal(e) => { + logWarning(s"Attempted to get executor loss reason" + + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + + s" but got no response. Marking as slave lost.", e) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) + } + case t => throw t + } + case None => + logWarning("Attempted to check for an executor loss reason" + + " before the AM has registered!") + } + } + override def receive: PartialFunction[Any, Unit] = { case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") - amEndpoint = Some(am) + amEndpoint = Option(am) + // See SPARK-10987. + am.send(DriverHello) case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) @@ -113,6 +180,7 @@ private[spark] abstract class YarnSchedulerBackend( removeExecutor(executorId, reason) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case r: RequestExecutors => amEndpoint match { @@ -143,7 +211,6 @@ private[spark] abstract class YarnSchedulerBackend( logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) } - } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 415c54cec9e4c..879e907d7932b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,7 +30,7 @@ import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcAddress -import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} @@ -139,7 +139,12 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() val driver = createSchedulerDriver( - master, CoarseMesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + CoarseMesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } @@ -234,6 +239,10 @@ private[spark] class CoarseMesosSchedulerBackend( markRegistered() } + override def sufficientResourcesRegistered(): Boolean = { + totalCoresAcquired >= maxCores * 
minRegisteredRatio + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -386,7 +395,7 @@ private[spark] class CoarseMesosSchedulerBackend( stateLock.synchronized { if (slaveIdsWithExecutors.contains(slaveId) && taskIdToSlaveId.contains(executorId)) { taskIdToSlaveId.remove(executorId) - removeExecutor(executorId, reason) + removeExecutor(executorId, SlaveLost(reason)) val newCount = slaveIdsWithExecutors(slaveId) - 1 if (newCount == 0) { slaveIdsWithExecutors.remove(slaveId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 07da9242b9922..a6d9374eb9e8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -29,7 +29,6 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.{Scheduler, SchedulerDriver} - import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem @@ -375,21 +374,20 @@ private[spark] class MesosClusterScheduler( val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") envBuilder.addVariables( Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) - val cmdOptions = generateCmdOption(desc).mkString(" ") val dockerDefined = desc.schedulerProperties.contains("spark.mesos.executor.docker.image") val executorUri = desc.schedulerProperties.get("spark.executor.uri") .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) - val appArguments = desc.command.arguments.mkString(" ") - val (executable, jar) = if (dockerDefined) { + // Gets the path to run spark-submit, and the path to the Mesos sandbox. + val (executable, sandboxPath) = if (dockerDefined) { // Application jar is automatically downloaded in the mounted sandbox by Mesos, // and the path to the mounted volume is stored in $MESOS_SANDBOX env variable. - ("./bin/spark-submit", s"$$MESOS_SANDBOX/${desc.jarUrl.split("/").last}") + ("./bin/spark-submit", "$MESOS_SANDBOX") } else if (executorUri.isDefined) { builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) val folderBasename = executorUri.get.split('/').last.split('.').head val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" - val cmdJar = s"../${desc.jarUrl.split("/").last}" - (cmdExecutable, cmdJar) + // Sandbox path points to the parent folder as we chdir into the folderBasename. + (cmdExecutable, "..") } else { val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") .orElse(conf.getOption("spark.home")) @@ -398,30 +396,50 @@ private[spark] class MesosClusterScheduler( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath - val cmdJar = desc.jarUrl.split("/").last - (cmdExecutable, cmdJar) + // Sandbox points to the current directory by default with Mesos. 
+ (cmdExecutable, ".") } - builder.setValue(s"$executable $cmdOptions $jar $appArguments") + val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") + val appArguments = desc.command.arguments.mkString(" ") + builder.setValue(s"$executable $cmdOptions $primaryResource $appArguments") builder.setEnvironment(envBuilder.build()) conf.getOption("spark.mesos.uris").map { uris => setupUris(uris, builder) } + desc.schedulerProperties.get("spark.mesos.uris").map { uris => + setupUris(uris, builder) + } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + setupUris(pyFiles, builder) + } builder.build() } - private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + private def generateCmdOption(desc: MesosDriverDescription, sandboxPath: String): Seq[String] = { var options = Seq( "--name", desc.schedulerProperties("spark.app.name"), - "--class", desc.command.mainClass, "--master", s"mesos://${conf.get("spark.master")}", "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + + // Assume empty main class means we're running python + if (!desc.command.mainClass.equals("")) { + options ++= Seq("--class", desc.command.mainClass) + } + desc.schedulerProperties.get("spark.executor.memory").map { v => options ++= Seq("--executor-memory", v) } desc.schedulerProperties.get("spark.cores.max").map { v => options ++= Seq("--total-executor-cores", v) } + desc.schedulerProperties.get("spark.submit.pyFiles").map { pyFiles => + val formattedFiles = pyFiles.split(",") + .map { path => new File(sandboxPath, path.split("/").last).toString() } + .mkString(",") + options ++= Seq("--py-files", formattedFiles) + } options } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 2e424054be785..6196176c7cc33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -32,7 +32,6 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils - /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks @@ -69,7 +68,12 @@ private[spark] class MesosSchedulerBackend( override def start() { classLoader = Thread.currentThread.getContextClassLoader val driver = createSchedulerDriver( - master, MesosSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf) + master, + MesosSchedulerBackend.this, + sc.sparkUser, + sc.appName, + sc.conf, + sc.ui.map(_.appUIAddress)) startScheduler(driver) } @@ -127,7 +131,7 @@ private[spark] class MesosSchedulerBackend( } val builder = MesosExecutorInfo.newBuilder() val (resourcesAfterCpu, usedCpuResources) = - partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) + partitionResources(availableResources, "cpus", mesosExecutorCores) val (resourcesAfterMem, usedMemResources) = partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) @@ -390,7 +394,7 @@ private[spark] class MesosSchedulerBackend( slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) + recordSlaveLost(d, slaveId, ExecutorExited(status, isNormalExit = false)) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 4d48fcfea44e7..c633d860ae6e5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -24,6 +24,7 @@ import java.nio.ByteBuffer import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -103,6 +104,9 @@ private[spark] class LocalBackend( private var localEndpoint: RpcEndpointRef = null private val userClassPath = getUserClasspath(conf) private val listenerBus = scheduler.sc.listenerBus + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) + } /** * Returns a list of URLs representing the user classpath. 
@@ -114,6 +118,8 @@ private[spark] class LocalBackend( userClassPathStr.map(_.split(File.pathSeparator)).toSeq.flatten.map(new File(_).toURI.toURL) } + launcherBackend.connect() + override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) @@ -122,10 +128,12 @@ private[spark] class LocalBackend( System.currentTimeMillis, executorEndpoint.localExecutorId, new ExecutorInfo(executorEndpoint.localExecutorHostname, totalCores, Map.empty))) + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def stop() { - localEndpoint.ask(StopExecutor) + stop(SparkAppHandle.State.FINISHED) } override def reviveOffers() { @@ -145,4 +153,13 @@ private[spark] class LocalBackend( override def applicationId(): String = appId + private def stop(finalState: SparkAppHandle.State): Unit = { + localEndpoint.ask(StopExecutor) + try { + launcherBackend.setState(finalState) + } finally { + launcherBackend.close() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index b977711e7d5ad..c5195c1143a8f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -35,7 +35,6 @@ import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, Roaring import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ @@ -362,9 +361,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[PutBlock], - classOf[GotBlock], - classOf[GetBlock], classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], classOf[RoaringBitmap], diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 94% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1b..7c3e2b5a3703b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,19 @@ * limitations under the License. */ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +/** + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. 
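The rename captures what the reader actually does: for its assigned range [startPartition, endPartition), look up which executors hold blocks of which sizes and fetch those blocks from their block stores, which is exactly what the extra endPartition argument to getMapSizesByExecutorId enables. A toy sketch of that range-based lookup (the types and the per-map size layout are simplified stand-ins for MapStatus data, not Spark's classes):

object RangeFetchSketch {
  case class BlockId(shuffleId: Int, mapId: Int, reduceId: Int)

  // Build the fetch list for one reducer's partition range [startPartition, endPartition),
  // grouped by the executor that hosts the blocks.
  def blocksToFetch(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      // executorId -> sizes(mapId)(reduceId), a simplified stand-in for map output metadata
      sizesByExecutor: Map[String, Array[Array[Long]]]): Map[String, Seq[(BlockId, Long)]] = {
    sizesByExecutor.map { case (executorId, sizesByMap) =>
      val blocks = for {
        (reduceSizes, mapId) <- sizesByMap.zipWithIndex.toSeq
        reduceId <- startPartition until endPartition
        size = reduceSizes(reduceId)
        if size > 0 // zero-sized blocks are skipped, nothing to fetch
      } yield (BlockId(shuffleId, mapId, reduceId), size)
      executorId -> blocks
    }
  }
}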
+ */ +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, @@ -33,9 +36,6 @@ private[spark] class HashShuffleReader[K, C]( mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] with Logging { - require(endPartition == startPartition + 1, - "Hash shuffle currently only supports fetching one partition") - private val dep = handle.dependency /** Read the combined key-values for this reduce task */ @@ -44,7 +44,7 @@ private[spark] class HashShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4df..cd253a78c2b19 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. 
Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). @@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. 
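With consolidation removed, the resolver's bookkeeping collapses to one file per (shuffleId, mapId, reduceId): a map task opens one writer per reducer, getBlockData maps a block id straight to its file, and cleanup iterates over completed map tasks times reducers. A toy sketch of that layout (the file naming here is illustrative, not the DiskBlockManager's actual scheme):

import java.io.File

object HashShuffleLayoutSketch {
  case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

  // One file per (shuffle, map, reduce); naming is illustrative only.
  def blockFile(root: File, id: ShuffleBlockId): File =
    new File(root, s"shuffle_${id.shuffleId}_${id.mapId}_${id.reduceId}.data")

  // A map task simply creates one output file per reducer...
  def filesForMapTask(root: File, shuffleId: Int, mapId: Int, numReducers: Int): Seq[File] =
    (0 until numReducers).map(r => blockFile(root, ShuffleBlockId(shuffleId, mapId, r)))

  // ...and cleanup walks completed map tasks times reducers, warning on failed deletes.
  def removeShuffle(root: File, shuffleId: Int, completedMapIds: Seq[Int], numReducers: Int): Unit =
    for (mapId <- completedMapIds; reduceId <- 0 until numReducers) {
      val f = blockFile(root, ShuffleBlockId(shuffleId, mapId, reduceId))
      if (f.exists() && !f.delete()) println(s"Error deleting ${f.getPath}")
    }
}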
@@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. 
*/ @@ -209,16 +125,11 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + val file = blockManager.diskBlockManager.getFile(blockId) + if (!file.delete()) { + logWarning(s"Error deleting ${file.getPath()}") } } logInfo("Deleted all files for shuffle " + shuffleId) @@ -229,10 +140,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +148,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d0163d326dba7..5e4c2b5d0a5c4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,7 +21,7 @@ import java.io._ import com.google.common.io.ByteStreams -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ @@ -40,7 +40,8 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). -private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver { +private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver + with Logging { private lazy val blockManager = SparkEnv.get.blockManager @@ -60,12 +61,16 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting data ${file.getPath()}") + } } file = getIndexFile(shuffleId, mapId) if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting index ${file.getPath()}") + } } } @@ -114,9 +119,8 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. + // No-op reduce ID used in interactions with disk store. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. - // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. 
val NOOP_REDUCE_ID = 0 } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index a0d8abc2eecb3..9bd18da47f1a2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -18,11 +18,14 @@ package org.apache.spark.shuffle import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.annotations.VisibleForTesting +import org.apache.spark._ +import org.apache.spark.memory.{StaticMemoryManager, MemoryManager} +import org.apache.spark.storage.{BlockId, BlockStatus} import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling @@ -35,17 +38,17 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access on "this" to mutate state and using - * wait() and notifyAll() to signal changes. + * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state + * and using wait() and notifyAll() to signal changes. * * Use `ShuffleMemoryManager.create()` factory method to create a new instance. * - * @param maxMemory total amount of memory available for execution, in bytes. + * @param memoryManager the interface through which this manager acquires execution memory * @param pageSizeBytes number of bytes for each page, by default. */ private[spark] class ShuffleMemoryManager protected ( - val maxMemory: Long, + memoryManager: MemoryManager, val pageSizeBytes: Long) extends Logging { @@ -63,7 +66,7 @@ class ShuffleMemoryManager protected ( * total memory pool (where N is the # of active tasks) before it is forced to spill. This can * happen if the number of tasks increases but an older task had a lot of memory already. */ - def tryToAcquire(numBytes: Long): Long = synchronized { + def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) @@ -71,65 +74,89 @@ class ShuffleMemoryManager protected ( // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire if (!taskMemory.contains(taskAttemptId)) { taskMemory(taskAttemptId) = 0L - notifyAll() // Will later cause waiting tasks to wake up and check numThreads again + // This will later cause waiting tasks to wake up and check numTasks again + memoryManager.notifyAll() } // Keep looping until we're either sure that we don't want to grant this request (because this // task would have more than 1 / numActiveTasks of the memory) or we have enough free // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
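The fairness policy is easiest to see in isolation: with N active tasks, a task is never granted more than 1/N of the pool, and while it still holds less than 1/(2N) it waits for memory to be freed rather than accept a grant too small to reach that floor. A pure-function sketch of a single grant decision under those rules (the real code additionally loops and blocks on the memoryManager lock):

object ShuffleMemoryPolicySketch {
  sealed trait Decision
  case class Grant(bytes: Long) extends Decision
  case object WaitForMemory extends Decision

  def decide(
      requested: Long,
      curMem: Long,         // bytes already held by this task
      totalHeld: Long,      // bytes held by all active tasks together
      maxMemory: Long,
      numActiveTasks: Long): Decision = {
    val freeMemory = maxMemory - totalHeld
    // Cap at 1/N of the pool per task, and never grant a negative amount.
    val maxToGrant = math.min(requested, math.max(0L, maxMemory / numActiveTasks - curMem))
    val toGrant = math.min(maxToGrant, freeMemory)
    val fairShareFloor = maxMemory / (2 * numActiveTasks)
    if (curMem < fairShareFloor &&
        freeMemory < math.min(maxToGrant, fairShareFloor - curMem)) {
      // Below the 1/(2N) floor and not enough free memory to reach it:
      // wait for other tasks to release memory instead of spilling too early.
      WaitForMemory
    } else {
      Grant(toGrant)
    }
  }
}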
+ // TODO: simplify this to limit each task to its own slot while (true) { val numActiveTasks = taskMemory.keys.size val curMem = taskMemory(taskAttemptId) + val maxMemory = memoryManager.maxExecutionMemory val freeMemory = maxMemory - taskMemory.values.sum // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, freeMemory) if (curMem < maxMemory / (2 * numActiveTasks)) { // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; // if we can't give it this much now, wait for other tasks to free up memory // (this happens if older tasks allocated lots of memory before N grew) if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant + return acquire(toGrant) } else { logInfo( s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") - wait() + memoryManager.wait() } } else { - // Only give it as much memory as is free, which might be none if it reached 1 / numThreads - val toGrant = math.min(maxToGrant, freeMemory) - taskMemory(taskAttemptId) += toGrant - return toGrant + return acquire(toGrant) } } 0L // Never reached } + /** + * Acquire N bytes of execution memory from the memory manager for the current task. + * @return number of bytes actually acquired (<= N). + */ + private def acquire(numBytes: Long): Long = memoryManager.synchronized { + val taskAttemptId = currentTaskAttemptId() + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + // TODO: just do this in `acquireExecutionMemory` (SPARK-10985) + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + taskMemory(taskAttemptId) += acquired + acquired + } + /** Release numBytes bytes for the current task. */ - def release(numBytes: Long): Unit = synchronized { + def release(numBytes: Long): Unit = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") + s"Internal error: release called on $numBytes bytes but task only has $curMem") + } + if (taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) -= numBytes + memoryManager.releaseExecutionMemory(numBytes) } - taskMemory(taskAttemptId) -= numBytes - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed } /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). 
*/ - def releaseMemoryForThisTask(): Unit = synchronized { + def releaseMemoryForThisTask(): Unit = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() - taskMemory.remove(taskAttemptId) - notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed + taskMemory.remove(taskAttemptId).foreach { numBytes => + memoryManager.releaseExecutionMemory(numBytes) + } + memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed } /** Returns the memory consumption, in bytes, for the current task */ - def getMemoryConsumptionForThisTask(): Long = synchronized { + def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized { val taskAttemptId = currentTaskAttemptId() taskMemory.getOrElse(taskAttemptId, 0L) } @@ -138,30 +165,28 @@ class ShuffleMemoryManager protected ( private[spark] object ShuffleMemoryManager { - def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = { - val maxMemory = ShuffleMemoryManager.getMaxMemory(conf) + def create( + conf: SparkConf, + memoryManager: MemoryManager, + numCores: Int): ShuffleMemoryManager = { + val maxMemory = memoryManager.maxExecutionMemory val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) - new ShuffleMemoryManager(maxMemory, pageSize) + new ShuffleMemoryManager(memoryManager, pageSize) } + /** + * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size. + */ def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { - new ShuffleMemoryManager(maxMemory, pageSizeBytes) + val conf = new SparkConf + val memoryManager = new StaticMemoryManager( + conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue) + new ShuffleMemoryManager(memoryManager, pageSizeBytes) } @VisibleForTesting def createForTesting(maxMemory: Long): ShuffleMemoryManager = { - new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024) - } - - /** - * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction - * of the memory pool and a safety factor since collections can sometimes grow bigger than - * the size we target before we estimate their sizes again. 
- */ - private def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2) - val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + create(maxMemory, 4 * 1024 * 1024) } /** @@ -177,7 +202,6 @@ private[spark] object ShuffleMemoryManager { val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case val safetyFactor = 16 - // TODO(davies): don't round to next power of 2 val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409dd..d2e2fc4c110a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). */ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) @@ -45,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b8..1105167d39d8d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,14 +19,67 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark._ +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +/** + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. 
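The idea behind the single map output file is compact: group (or sort) records by their target partition id, write the partitions back to back into one data file, and keep the per-partition lengths so each reducer can seek to its contiguous region, which is the role the index file plays on disk. A toy in-memory sketch of that layout, ignoring spilling, serialization, and the actual index file format:

object SortShuffleLayoutSketch {
  // Lay out one map task's output for `numPartitions` reducers as a single concatenated
  // byte array plus per-partition lengths (the on-disk equivalent is the data + index file pair).
  def writeSingleMapOutput(
      records: Seq[(Int, Array[Byte])], // (partitionId, already-serialized record)
      numPartitions: Int): (Array[Byte], Array[Long]) = {
    val byPartition = records.groupBy(_._1)
    val lengths = new Array[Long](numPartitions)
    val out = new scala.collection.mutable.ArrayBuffer[Byte]()
    // Emit partitions in order, back to back, recording each partition's length.
    for (p <- 0 until numPartitions) {
      val bytes = byPartition.getOrElse(p, Seq.empty).flatMap(_._2)
      lengths(p) = bytes.length.toLong
      out ++= bytes
    }
    (out.toArray, lengths)
  }

  // A reducer reads a contiguous region: its offset is the sum of all earlier partition lengths.
  def readPartition(data: Array[Byte], lengths: Array[Long], partition: Int): Array[Byte] = {
    val offset = lengths.take(partition).sum.toInt
    data.slice(offset, offset + lengths(partition).toInt)
  }
}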
+ * + * Sort-based shuffle has two different write paths for producing its map output files: + * + * - Serialized sorting: used when all three of the following conditions hold: + * 1. The shuffle dependency specifies no aggregation or output ordering. + * 2. The shuffle serializer supports relocation of serialized values (this is currently + * supported by KryoSerializer and Spark SQL's custom serializers). + * 3. The shuffle produces fewer than 16777216 output partitions. + * - Deserialized sorting: used to handle all other cases. + * + * ----------------------- + * Serialized sorting mode + * ----------------------- + * + * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the + * shuffle writer and are buffered in a serialized form during sorting. This write path implements + * several optimizations: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on these optimizations, see SPARK-7081. + */ +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) - private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } + + /** + * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + */ + private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** * Register a shuffle with the manager and obtain a handle for it to pass to tasks. @@ -35,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) + if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need map-side aggregation, then write numPartitions files directly and just concatenate + // them at the end. 
This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { + // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: + new SerializedShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + // Otherwise, buffer map outputs in a deserialized form: + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } } /** @@ -47,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]] - shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps) - new SortShuffleWriter( - shuffleBlockResolver, baseShuffleHandle, mapId, context) + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + numMapsForShuffle.putIfAbsent( + handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val env = SparkEnv.get + handle match { + case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => + new BypassMergeSortShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + bypassMergeSortHandle, + mapId, + context, + env.conf) + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => + new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + } } /** Remove a shuffle's metadata from the ShuffleManager. */ override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shuffleMapNumber.containsKey(shuffleId)) { - val numMaps = shuffleMapNumber.remove(shuffleId) - (0 until numMaps).map{ mapId => + Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => shuffleBlockResolver.removeDataByMap(shuffleId, mapId) } } true } - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - indexShuffleBlockResolver - } - /** Shut down this ShuffleManager. */ override def stop(): Unit = { shuffleBlockResolver.stop() } } + +private[spark] object SortShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that SortShuffleManager supports when + * buffering map outputs in a serialized form. 
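Taken together, write-path selection is a three-way cascade: bypass-merge-sort when there is no map-side combine and the partition count is under the bypass threshold, serialized sorting when the conditions listed in the class comment hold, and the deserialized ExternalSorter path otherwise. A condensed sketch of that decision with the inputs flattened into plain parameters (the real code reads them from the ShuffleDependency and SparkConf):

object WritePathSelectionSketch {
  sealed trait WritePath
  case object BypassMergeSort extends WritePath
  case object SerializedSort extends WritePath
  case object DeserializedSort extends WritePath

  // 16777216 partitions, i.e. PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 in the real code.
  val MaxSerializedModePartitions: Int = 1 << 24

  def choose(
      numPartitions: Int,
      mapSideCombine: Boolean,
      hasAggregator: Boolean,
      hasKeyOrdering: Boolean,
      serializerSupportsRelocation: Boolean,
      bypassMergeThreshold: Int = 200): WritePath = {
    if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
      // Few partitions, no map-side combine: write one file per partition and concatenate.
      BypassMergeSort
    } else if (serializerSupportsRelocation && !hasAggregator && !hasKeyOrdering &&
               numPartitions <= MaxSerializedModePartitions) {
      // Sort serialized bytes by partition id (the serialized sorting mode described above).
      SerializedSort
    } else {
      // Fall back to buffering deserialized objects in an ExternalSorter.
      DeserializedSort
    }
  }
}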
This is an extreme defensive programming measure, + * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. + * */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use an optimized serialized shuffle + * path or whether it should fall back to the original path that operates on deserialized objects. + */ + def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = { + val shufId = dependency.shuffleId + val numPartitions = dependency.partitioner.numPartitions + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug( + s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined") + false + } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions") + false + } else { + log.debug(s"Can use serialized shuffle for shuffle $shufId") + true + } + } +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * serialized shuffle. + */ +private[spark] class SerializedShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the + * bypass merge sort shuffle path. + */ +private[spark] class BypassMergeSortShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 5865e7640c1cf..bbd9c1ab53cd8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: SortShuffleFileWriter[K, V] = null + private var sorter: ExternalSorter[K, V, _] = null // Are we in the process of stopping? 
Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C]( require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - } else if (SortShuffleWriter.shouldBypassMergeSort( - SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't - // need local aggregation and sorting, write numPartitions files directly and just concatenate - // them at the end. This avoids doing serialization and deserialization twice to merge - // together the spilled files, which would happen with the normal code path. The downside is - // having multiple files open at a time and thus more memory allocated to buffers. - new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, - writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side @@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C]( } private[spark] object SortShuffleWriter { - def shouldBypassMergeSort( - conf: SparkConf, - numPartitions: Int, - aggregator: Option[Aggregator[_, _, _]], - keyOrdering: Option[Ordering[_]]): Boolean = { - val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) - numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = { + // We cannot bypass sorting if we need to do map-side aggregation. + if (dep.mapSideCombine) { + require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") + false + } else { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + dep.partitioner.numPartitions <= bypassMergeThreshold + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala deleted file mode 100644 index df7bbd64247dd..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.unsafe - -import java.util.Collections -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark._ -import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.sort.SortShuffleManager - -/** - * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. - */ -private[spark] class UnsafeShuffleHandle[K, V]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { -} - -private[spark] object UnsafeShuffleManager extends Logging { - - /** - * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. - */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 - - /** - * Helper method for determining whether a shuffle should use the optimized unsafe shuffle - * path or whether it should fall back to the original sort-based shuffle. - */ - def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { - val shufId = dependency.shuffleId - val serializer = Serializer.getSerializer(dependency.serializer) - if (!serializer.supportsRelocationOfSerializedObjects) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + - s"${serializer.getClass.getName}, does not support object relocation") - false - } else if (dependency.aggregator.isDefined) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") - false - } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { - log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + - s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") - false - } else { - log.debug(s"Can use UnsafeShuffle for shuffle $shufId") - true - } - } -} - -/** - * A shuffle implementation that uses directly-managed memory to implement several performance - * optimizations for certain types of shuffles. In cases where the new performance optimizations - * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those - * shuffles. - * - * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: - * - * - The shuffle dependency specifies no aggregation or output ordering. - * - The shuffle serializer supports relocation of serialized values (this is currently supported - * by KryoSerializer and Spark SQL's custom serializers). - * - The shuffle produces fewer than 16777216 output partitions. - * - No individual record is larger than 128 MB when serialized. - * - * In addition, extra spill-merging optimizations are automatically applied when the shuffle - * compression codec supports concatenation of serialized streams. This is currently supported by - * Spark's LZF serializer. - * - * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. - * In sort-based shuffle, incoming records are sorted according to their target partition ids, then - * written to a single map output file. Reducers fetch contiguous regions of this file in order to - * read their portion of the map output. In cases where the map output data is too large to fit in - * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged - * to produce the final output file. 
- * - * UnsafeShuffleManager optimizes this process in several ways: - * - * - Its sort operates on serialized binary data rather than Java objects, which reduces memory - * consumption and GC overheads. This optimization requires the record serializer to have certain - * properties to allow serialized records to be re-ordered without requiring deserialization. - * See SPARK-4550, where this optimization was first proposed and implemented, for more details. - * - * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts - * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per - * record in the sorting array, this fits more of the array into cache. - * - * - The spill merging procedure operates on blocks of serialized records that belong to the same - * partition and does not need to deserialize records during the merge. - * - * - When the spill compression codec supports concatenation of compressed data, the spill merge - * simply concatenates the serialized and compressed spill partitions to produce the final output - * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used - * and avoids the need to allocate decompression or copying buffers during the merge. - * - * For more details on UnsafeShuffleManager's design, see SPARK-7081. - */ -private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + - "manager; its optimized shuffles will continue to spill to disk when necessary.") - } - - private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) - private[this] val shufflesThatFellBackToSortShuffle = - Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) - private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() - - /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. - */ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { - new UnsafeShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) - } else { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - sortShuffleManager.getReader(handle, startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. 
*/ - override def getWriter[K, V]( - handle: ShuffleHandle, - mapId: Int, - context: TaskContext): ShuffleWriter[K, V] = { - handle match { - case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => - numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) - val env = SparkEnv.get - new UnsafeShuffleWriter( - env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], - context.taskMemoryManager(), - env.shuffleMemoryManager, - unsafeShuffleHandle, - mapId, - context, - env.conf) - case other => - shufflesThatFellBackToSortShuffle.add(handle.shuffleId) - sortShuffleManager.getWriter(handle, mapId, context) - } - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { - sortShuffleManager.unregisterShuffle(shuffleId) - } else { - Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } - } - true - } - } - - override val shuffleBlockResolver: IndexShuffleBlockResolver = { - sortShuffleManager.shuffleBlockResolver - } - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = { - sortShuffleManager.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b3..24a0b5220695c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala new file mode 100644 index 0000000000000..f6e46ae9a481a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.apache.spark.SparkException + +private[spark] +case class BlockFetchException(messages: String, throwable: Throwable) + extends SparkException(messages, throwable) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fefaef0ab82c8..c374b93766225 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -23,6 +23,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import scala.util.Random import sun.nio.ch.DirectBuffer @@ -30,6 +31,7 @@ import sun.nio.ch.DirectBuffer import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec +import org.apache.spark.memory.MemoryManager import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf @@ -63,8 +65,8 @@ private[spark] class BlockManager( rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, - maxMemory: Long, val conf: SparkConf, + memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, @@ -81,12 +83,19 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false - private[spark] val memoryStore = new MemoryStore(this, maxMemory) + private[spark] val memoryStore = new MemoryStore(this, memoryManager) private[spark] val diskStore = new DiskStore(this, diskBlockManager) private[spark] lazy val externalBlockStore: ExternalBlockStore = { externalBlockStoreInitialized = true new ExternalBlockStore(this, executorId) } + memoryManager.setMemoryStore(memoryStore) + + // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. + // However, since we use this only for reporting and logging, what we actually want here is + // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need + // to revisit whether reporting this value as the "max" is intuitive to the user. + private val maxMemory = memoryManager.maxStorageMemory private[spark] val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -105,15 +114,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external @@ -165,24 +165,6 @@ private[spark] class BlockManager( * loaded yet. 
*/ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this( - execId: String, - rpcEnv: RpcEnv, - master: BlockManagerMaster, - serializer: Serializer, - conf: SparkConf, - mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService, - securityManager: SecurityManager, - numUsableCores: Int) = { - this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - } - /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -600,10 +582,26 @@ private[spark] class BlockManager( private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") val locations = Random.shuffle(master.getLocations(blockId)) + var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + val data = try { + blockTransferService.fetchBlockSync( + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + } catch { + case NonFatal(e) => + numFetchFailures += 1 + if (numFetchFailures == locations.size) { + // An exception is thrown while fetching this block from all locations + throw new BlockFetchException(s"Failed to fetch block from" + + s" ${locations.size} locations. Most recent failure cause:", e) + } else { + // This location failed, so we retry fetch from a different one by returning null here + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $numFetchFailures)", e) + null + } + } if (data != null) { if (asBlockResult) { @@ -661,7 +659,7 @@ private[spark] class BlockManager( writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, + new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, syncWrites, writeMetrics) } @@ -1259,13 +1257,6 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { private val ID_GENERATOR = new IdGenerator - /** Return the total amount of storage memory available. */ - private def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) - val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9) - (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong - } - /** * Attempt to clean up a ByteBuffer if it is memory-mapped. 
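// Generic sketch (names are illustrative, not Spark APIs) of the retry shape that doGetRemote
// now follows: try each known location in turn, tolerate a non-fatal failure and move on, and
// only raise an error (BlockFetchException in the real code) once every location has failed.
import scala.util.control.NonFatal

def fetchFromAny[A](locations: Seq[String])(fetchOne: String => A): A = {
  var lastCause: Throwable = null
  for (loc <- locations) {
    try {
      return fetchOne(loc)               // success: stop trying further locations
    } catch {
      case NonFatal(e) => lastCause = e  // remember the most recent cause and keep going
    }
  }
  throw new RuntimeException(s"Failed to fetch from ${locations.size} locations", lastCause)
}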
This uses an *unsafe* Sun API that * might cause errors if one attempts to read from the unmapped buffer, but it's better than diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 7478ab0fc2f7a..e749631bf6f19 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ @@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends RpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") @@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint( future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) context.reply(response) - logDebug("Sent response: " + response + " to " + context.sender) + logDebug("Sent response: " + response + " to " + context.senderAddress) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3f8d26e1d4cab..f7e84a2c2e14c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -164,7 +164,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 49d9154f95a5b..80d426fadc65e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -34,7 +34,6 @@ import org.apache.spark.util.Utils * reopened again. */ private[spark] class DiskBlockObjectWriter( - val blockId: BlockId, file: File, serializerInstance: SerializerInstance, bufferSize: Int, @@ -144,8 +143,10 @@ private[spark] class DiskBlockObjectWriter( * Reverts writes that haven't been flushed yet. Callers should invoke this function * when there are runtime exceptions. This method will not throw, though it may be * unsuccessful in truncating written data. + * + * @return the file that this DiskBlockObjectWriter wrote to. 
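// Hedged caller sketch: because revertPartialWritesAndClose() now returns the File it was
// writing, failure paths can both truncate the partial data and remove the file. Assumes a
// DiskBlockObjectWriter named `writer` and Spark's Logging trait in scope; the delete step is
// illustrative, not something this patch adds.
val revertedFile = writer.revertPartialWritesAndClose()
if (revertedFile.exists() && !revertedFile.delete()) {
  logWarning(s"Error deleting $revertedFile after reverting partial writes")
}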
*/ - def revertPartialWritesAndClose() { + def revertPartialWritesAndClose(): File = { // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. try { @@ -160,12 +161,14 @@ private[spark] class DiskBlockObjectWriter( val truncateStream = new FileOutputStream(file, true) try { truncateStream.getChannel.truncate(initialPosition) + file } finally { truncateStream.close() } } catch { case e: Exception => logError("Uncaught exception while reverting partial writes to file " + file, e) + file } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f45956282166..c008b9dc16327 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -86,7 +86,9 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } catch { case e: Throwable => if (file.exists()) { - file.delete() + if (!file.delete()) { + logWarning(s"Error deleting ${file}") + } } throw e } @@ -154,11 +156,12 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { - file.delete() + val ret = file.delete() + if (!ret) { + logWarning(s"Error deleting ${file.getPath()}") + } + ret } else { false } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 6f27f00307f8c..4dbac388e098b 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.memory.MemoryManager import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -33,19 +34,17 @@ private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean) * Stores blocks in memory, either as Arrays of deserialized Java objects or as * serialized ByteBuffers. */ -private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) +private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager) extends BlockStore(blockManager) { + // Note: all changes to memory allocations, notably putting blocks, evicting blocks, and + // acquiring or releasing unroll memory, must be synchronized on `memoryManager`! 
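// The comment just added above is the central invariant of this refactoring: all storage and
// unroll memory accounting must happen while holding the `memoryManager` lock. A condensed
// restatement of the pattern tryToPut uses (identifiers refer to MemoryStore members and
// arguments, so this is not standalone code):
memoryManager.synchronized {
  // Re-use whatever this task already reserved while unrolling this block...
  releasePendingUnrollMemoryForThisTask()
  // ...then, atomically with respect to other tasks, acquire storage memory for the block.
  if (memoryManager.acquireStorageMemory(blockId, size, droppedBlocks)) {
    // Safe to record the block in `entries`: the accounting has already been reserved.
  }
}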
+ private val conf = blockManager.conf private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true) - @volatile private var currentMemory = 0L - - // Ensure only one thread is putting, and if necessary, dropping blocks at any given time - private val accountingLock = new Object - // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) - // All accesses of this map are assumed to have manually synchronized on `accountingLock` + // All accesses of this map are assumed to have manually synchronized on `memoryManager` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. // Pending unroll memory refers to the intermediate memory occupied by a task @@ -56,19 +55,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // memory (SPARK-4777). private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]() - /** - * The amount of space ensured for unrolling values in memory, shared across all cores. - * This space is not reserved in advance, but allocated dynamically by dropping existing blocks. - */ - private val maxUnrollMemory: Long = { - val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2) - (maxMemory * unrollFraction).toLong - } - // Initial memory to request before unrolling any block private val unrollMemoryThreshold: Long = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + /** Total amount of memory available for storage, in bytes. */ + private def maxMemory: Long = memoryManager.maxStorageMemory + if (maxMemory < unrollMemoryThreshold) { logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + @@ -77,8 +70,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) - /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */ - def freeMemory: Long = maxMemory - currentMemory + /** Total storage memory used including unroll memory, in bytes. */ + private def memoryUsed: Long = memoryManager.storageMemoryUsed + + /** + * Amount of storage memory, in bytes, used for caching blocks. + * This does not include memory used for unrolling. + */ + private def blocksMemoryUsed: Long = memoryManager.synchronized { + memoryUsed - currentUnrollMemory + } override def getSize(blockId: BlockId): Long = { entries.synchronized { @@ -94,8 +95,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -108,15 +110,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. 
lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] - val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks) val data = - if (putAttempt.success) { + if (putSuccess) { assert(bytes.limit == size) Right(bytes.duplicate()) } else { null } - PutResult(size, data, putAttempt.droppedBlocks) + PutResult(size, data, droppedBlocks) } override def putArray( @@ -124,14 +127,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { + val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - val putAttempt = tryToPut(blockId, values, sizeEstimate, deserialized = true) - PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks) + tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks) + PutResult(sizeEstimate, Left(values.iterator), droppedBlocks) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks) + PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks) } } @@ -208,24 +212,25 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: BlockId): Boolean = { - entries.synchronized { - val entry = entries.remove(blockId) - if (entry != null) { - currentMemory -= entry.size - logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)") - true - } else { - false - } + override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { + val entry = entries.synchronized { entries.remove(blockId) } + if (entry != null) { + memoryManager.releaseStorageMemory(entry.size) + logDebug(s"Block $blockId of size ${entry.size} dropped " + + s"from memory (free ${maxMemory - blocksMemoryUsed})") + true + } else { + false } } - override def clear() { + override def clear(): Unit = memoryManager.synchronized { entries.synchronized { entries.clear() - currentMemory = 0 } + unrollMemoryMap.clear() + pendingUnrollMemoryMap.clear() + memoryManager.releaseAllStorageMemory() logInfo("MemoryStore cleared") } @@ -265,7 +270,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -281,20 +286,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - // Hold the accounting lock, in case another thread concurrently puts a block that - // takes up the unrolling space we just ensured here - accountingLock.synchronized { - if (!reserveUnrollMemoryForThisTask(amountToRequest)) { - // If the first request 
is not granted, try again after ensuring free space - // If there is still not enough space, give up and drop the partition - val spaceToEnsure = maxUnrollMemory - currentUnrollMemory - if (spaceToEnsure > 0) { - val result = ensureFreeSpace(blockId, spaceToEnsure) - droppedBlocks ++= result.droppedBlocks - } - keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) - } - } + keepUnrolling = reserveUnrollMemoryForThisTask( + blockId, amountToRequest, droppedBlocks) // New threshold is currentSize * memoryGrowthFactor memoryThreshold += amountToRequest } @@ -312,16 +305,23 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } finally { - // If we return an array, the values returned will later be cached in `tryToPut`. - // In this case, we should release the memory after we cache the block there. - // Otherwise, if we return an iterator, we release the memory reserved here - // later when the task finishes. + // If we return an array, the values returned here will be cached in `tryToPut` later. + // In this case, we should release the memory only after we cache the block there. if (keepUnrolling) { - accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved - releaseUnrollMemoryForThisTask(amountToRelease) - reservePendingUnrollMemoryForThisTask(amountToRelease) + val taskAttemptId = currentTaskAttemptId() + memoryManager.synchronized { + // Since we continue to hold onto the array until we actually cache it, we cannot + // release the unroll memory yet. Instead, we transfer it to pending unroll memory + // so `tryToPut` can further transfer it to normal storage memory later. + // TODO: we can probably express this without pending unroll memory (SPARK-10907) + val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved + unrollMemoryMap(taskAttemptId) -= amountToTransferToPending + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending } + } else { + // Otherwise, if we return an iterator, we can only release the unroll memory when + // the task finishes since we don't know when the iterator will be consumed. } } } @@ -337,8 +337,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) blockId: BlockId, value: Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { - tryToPut(blockId, () => value, size, deserialized) + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + tryToPut(blockId, () => value, size, deserialized, droppedBlocks) } /** @@ -349,18 +350,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * `value` will be lazily created. If it cannot be put into MemoryStore or disk, `value` won't be * created to avoid OOM since it may be a big ByteBuffer. * - * Synchronize on `accountingLock` to ensure that all the put requests and its associated block + * Synchronize on `memoryManager` to ensure that all the put requests and its associated block * dropping is done by only on thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for * another block. * - * Return whether put was successful, along with the blocks dropped in the process. + * All blocks evicted in the process, if any, will be added to `droppedBlocks`. + * + * @return whether put was successful. 
*/ private def tryToPut( blockId: BlockId, value: () => Any, size: Long, - deserialized: Boolean): ResultWithDroppedBlocks = { + deserialized: Boolean, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has @@ -368,24 +372,24 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * for freeing up more space for another block that needs to be put. Only then the actually * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */ - var putSuccess = false - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - - accountingLock.synchronized { - val freeSpaceResult = ensureFreeSpace(blockId, size) - val enoughFreeSpace = freeSpaceResult.success - droppedBlocks ++= freeSpaceResult.droppedBlocks - - if (enoughFreeSpace) { + memoryManager.synchronized { + // Note: if we have previously unrolled this block successfully, then pending unroll + // memory should be non-zero. This is the amount that we already reserved during the + // unrolling process. In this case, we can just reuse this space to cache our block. + // The synchronization on `memoryManager` here guarantees that the release and acquire + // happen atomically. This relies on the assumption that all memory acquisitions are + // synchronized on the same lock. + releasePendingUnrollMemoryForThisTask() + val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks) + if (enoughMemory) { + // We acquired enough memory for the block, so go ahead and put it val entry = new MemoryEntry(value(), size, deserialized) entries.synchronized { entries.put(blockId, entry) - currentMemory += size } val valuesOrBytes = if (deserialized) "values" else "bytes" logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format( - blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - putSuccess = true + blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(blocksMemoryUsed))) } else { // Tell the block manager that we couldn't put it in memory so that it can drop it to // disk if the block allows disk storage. @@ -397,10 +401,35 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } - // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisTask() + enoughMemory } - ResultWithDroppedBlocks(putSuccess, droppedBlocks) + } + + /** + * Try to free up a given amount of space by evicting existing blocks. + * + * @param space the amount of memory to free, in bytes + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. + */ + private[spark] def ensureFreeSpace( + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + ensureFreeSpace(None, space, droppedBlocks) + } + + /** + * Try to free up a given amount of space to store a block by evicting existing ones. + * + * @param space the amount of memory to free, in bytes + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. 
+ */ + private[spark] def ensureFreeSpace( + blockId: BlockId, + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + ensureFreeSpace(Some(blockId), space, droppedBlocks) } /** @@ -409,40 +438,43 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * - * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping - * blocks. Otherwise, the freed space may fill up before the caller puts in their new value. - * - * Return whether there is enough free space, along with the blocks dropped in the process. + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param droppedBlocks a holder for blocks evicted in the process + * @return whether the requested free space is freed. */ private def ensureFreeSpace( - blockIdToAdd: BlockId, - space: Long): ResultWithDroppedBlocks = { - logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory") - - val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + blockId: Option[BlockId], + space: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + memoryManager.synchronized { + val freeMemory = maxMemory - memoryUsed + val rddToAdd = blockId.flatMap(getRddId) + val selectedBlocks = new ArrayBuffer[BlockId] + var selectedMemory = 0L - if (space > maxMemory) { - logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit") - return ResultWithDroppedBlocks(success = false, droppedBlocks) - } + logInfo(s"Ensuring $space bytes of free space " + + blockId.map { id => s"for block $id" }.getOrElse("") + + s"(free: $freeMemory, max: $maxMemory)") - // Take into account the amount of memory currently occupied by unrolling blocks - // and minus the pending unroll memory for that block on current thread. - val taskAttemptId = currentTaskAttemptId() - val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + // Fail fast if the block simply won't fit + if (space > maxMemory) { + logInfo("Will not " + blockId.map { id => s"store $id" }.getOrElse("free memory") + + s" as the required space ($space bytes) exceeds our memory limit ($maxMemory bytes)") + return false + } - if (actualFreeMemory < space) { - val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[BlockId] - var selectedMemory = 0L + // No need to evict anything if there is already enough free space + if (freeMemory >= space) { + return true + } // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that // can lead to exceptions. 
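// Test-style sketch (hypothetical `memoryStore`; these overloads are private[spark]) of the
// eviction behaviour described above: when space is freed for an RDD block, blocks of that
// same RDD are never chosen as eviction victims, to avoid a cyclic replacement pattern.
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.storage.{BlockId, BlockStatus, RDDBlockId}

val dropped = new ArrayBuffer[(BlockId, BlockStatus)]
// Try to free 16 MB for rdd_0_1; blocks belonging to RDD 0 are excluded from the candidates.
val freed: Boolean = memoryStore.ensureFreeSpace(RDDBlockId(0, 1), 16L * 1024 * 1024, dropped)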
entries.synchronized { val iterator = entries.entrySet().iterator() - while (actualFreeMemory + selectedMemory < space && iterator.hasNext) { + while (freeMemory + selectedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { @@ -452,7 +484,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - if (actualFreeMemory + selectedMemory >= space) { + if (freeMemory + selectedMemory >= space) { logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } @@ -469,14 +501,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } } - return ResultWithDroppedBlocks(success = true, droppedBlocks) + true } else { - logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " + - "from the same RDD") - return ResultWithDroppedBlocks(success = false, droppedBlocks) + blockId.foreach { id => + logInfo(s"Will not store $id as it would require dropping another block " + + "from the same RDD") + } + false } } - ResultWithDroppedBlocks(success = true, droppedBlocks) } override def contains(blockId: BlockId): Boolean = { @@ -489,17 +522,20 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Reserve additional memory for unrolling blocks used by this task. - * Return whether the request is granted. + * Reserve memory for unrolling the given block for this task. + * @return whether the request is granted. */ - def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - accountingLock.synchronized { - val granted = freeMemory > currentUnrollMemory + memory - if (granted) { + def reserveUnrollMemoryForThisTask( + blockId: BlockId, + memory: Long, + droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = { + memoryManager.synchronized { + val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks) + if (success) { val taskAttemptId = currentTaskAttemptId() unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } - granted + success } } @@ -507,73 +543,68 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this task for unrolling blocks. * If the amount is not specified, remove the current task's allocation altogether. 
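// Sketch of the unroll-memory lifecycle from a task's point of view under the new API
// (hypothetical `memoryStore`, `blockId` and `dropped`; all calls are spark-internal):
val granted = memoryStore.reserveUnrollMemoryForThisTask(blockId, 1024L * 1024, dropped)
if (granted) {
  try {
    // ... unroll values, calling reserveUnrollMemoryForThisTask again as the estimated
    //     size grows beyond the current reservation ...
  } finally {
    // With no argument, the task's entire remaining unroll reservation is released.
    memoryStore.releaseUnrollMemoryForThisTask()
  }
}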
*/ - def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - if (memory < 0) { - unrollMemoryMap.remove(taskAttemptId) - } else { - unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory - // If this task claims no more unroll memory, release it completely - if (unrollMemoryMap(taskAttemptId) <= 0) { - unrollMemoryMap.remove(taskAttemptId) + memoryManager.synchronized { + if (unrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + unrollMemoryMap(taskAttemptId) -= memoryToRelease + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) } } } } - /** - * Reserve the unroll memory of current unroll successful block used by this task - * until actually put the block into memory entry. - */ - def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { - val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - pendingUnrollMemoryMap(taskAttemptId) = - pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory - } - } - /** * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisTask(): Unit = { + def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() - accountingLock.synchronized { - pendingUnrollMemoryMap.remove(taskAttemptId) + memoryManager.synchronized { + if (pendingUnrollMemoryMap.contains(taskAttemptId)) { + val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId)) + if (memoryToRelease > 0) { + pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease + if (pendingUnrollMemoryMap(taskAttemptId) == 0) { + pendingUnrollMemoryMap.remove(taskAttemptId) + } + memoryManager.releaseUnrollMemory(memoryToRelease) + } + } } } /** * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ - def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = memoryManager.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** * Return the number of tasks currently unrolling blocks. */ - def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. */ - def logMemoryUsage(): Unit = { - val blocksMemory = currentMemory - val unrollMemory = currentUnrollMemory - val totalMemory = blocksMemory + unrollMemory + private def logMemoryUsage(): Unit = { logInfo( - s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + - s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. 
" + + s"Memory use = ${Utils.bytesToString(blocksMemoryUsed)} (blocks) + " + + s"${Utils.bytesToString(currentUnrollMemory)} (scratch space shared across " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(memoryUsed)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } @@ -584,7 +615,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * @param blockId ID of the block we are trying to unroll. * @param finalVectorSize Final size of the vector before unrolling failed. */ - def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { logWarning( s"Not enough space to cache $blockId in memory! " + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" @@ -592,7 +623,3 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logMemoryUsage() } } - -private[spark] case class ResultWithDroppedBlocks( - success: Boolean, - droppedBlocks: Seq[(BlockId, BlockStatus)]) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a759ceb96ec1e..0d0448feb5b06 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -260,10 +260,7 @@ final class ShuffleBlockFetcherIterator( fetchRequests ++= Utils.randomize(remoteRequests) // Send out initial requests for blocks, up to our maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() val numFetches = remoteRequests.size - fetchRequests.size logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) @@ -296,10 +293,7 @@ final class ShuffleBlockFetcherIterator( case _ => } // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } + fetchUpToMaxBytes() result match { case FailureFetchResult(blockId, address, e) => @@ -315,6 +309,14 @@ final class ShuffleBlockFetcherIterator( } } + private def fetchUpToMaxBytes(): Unit = { + // Send fetch requests up to maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + } + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 779c0ba083596..b796a44fe01ac 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). 
By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9a..99085ada9f0af 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). 
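// Usage sketch for the click-jacking mitigation added above: UI responses now always carry an
// X-Frame-Options header, SAMEORIGIN unless a single allowed origin is configured. The URL
// below is an example value only.
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.ui.allowFramingFrom", "https://example.com/")
// UI servlet responses will then include:
//   X-Frame-Options: ALLOW-FROM https://example.com/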
*/ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da417724104..68a9f912a5d2c 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -29,6 +31,7 @@ import org.apache.spark.ui.scope.RDDOperationGraph private[spark] object UIUtils extends Logging { val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed" val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" + val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -395,4 +398,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. So any + * attempts to embed links outside Spark UI, or other tags like +
+ } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 5779c71f64e9e..5a072de400b6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,19 +19,15 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.{JobExecutionStatus, Logging} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} +import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} -private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { +private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { - private val retainedExecutions = - sqlContext.sparkContext.conf.getInt("spark.sql.ui.retainedExecutions", 1000) + private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() @@ -130,7 +126,13 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit val stageId = stageSubmitted.stageInfo.stageId val stageAttemptId = stageSubmitted.stageInfo.attemptId // Always override metrics for old stage attempt - _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + if (_stageIdToStageMetrics.contains(stageId)) { + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId) + } else { + // If a stage belongs to some SQL execution, its stageId will be put in "onJobStart". + // Since "_stageIdToStageMetrics" doesn't contain it, it must not belong to any SQL execution. + // So we can ignore it. Otherwise, this may lead to memory leaks (SPARK-11126). + } } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { @@ -256,7 +258,7 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit /** * Get all accumulator updates from all tasks which belong to this execution and merge them. */ - def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized { + def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized { _executionIdToData.get(executionId) match { case Some(executionUIData) => val accumulatorUpdates = { @@ -268,8 +270,7 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => - executionUIData.accumulatorMetrics(accumulatorId).metricParam). 
- mapValues(_.asInstanceOf[SQLMetricValue[_]].value) + executionUIData.accumulatorMetrics(accumulatorId).metricParam) case None => // This execution has been dropped Map.empty @@ -278,11 +279,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit private def mergeAccumulatorUpdates( accumulatorUpdates: Seq[(Long, Any)], - paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = { + paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = { accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) => val param = paramFunc(accumulatorId) (accumulatorId, - values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace)) + param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index 0b0867f67eb6e..9c27944d42fc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,14 +20,12 @@ package org.apache.spark.sql.execution.ui import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext import org.apache.spark.ui.{SparkUI, SparkUITab} -private[sql] class SQLTab(sqlContext: SQLContext, sparkUI: SparkUI) +private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) extends SparkUITab(sparkUI, SQLTab.nextTabName) with Logging { val parent = sparkUI - val listener = sqlContext.listener attachPage(new AllExecutionsPage(this)) attachPage(new ExecutionPage(this)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index ae3d752dde348..f1fce5478a3fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} private[ui] case class SparkPlanGraph( nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { - def makeDotFile(metrics: Map[Long, Any]): String = { + def makeDotFile(metrics: Map[Long, String]): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n")) @@ -87,7 +87,7 @@ private[sql] object SparkPlanGraph { private[ui] case class SparkPlanGraphNode( id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) { - def makeDotNode(metricsValue: Map[Long, Any]): String = { + def makeDotNode(metricsValue: Map[Long, String]): String = { val values = { for (metric <- metrics; value <- metricsValue.get(metric.accumulatorId)) yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index c3d2246297021..8b9247adea200 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import 
org.apache.spark.sql.catalyst.expressions._ @@ -149,13 +150,17 @@ class WindowSpec private[sql]( case Count(child) => WindowExpression( UnresolvedWindowFunction("count", child :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child) => WindowExpression( + case First(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction("first_value", child :: Nil), + UnresolvedWindowFunction( + "first_value", + child :: ignoreNulls :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child) => WindowExpression( + case Last(child, ignoreNulls) => WindowExpression( // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction("last_value", child :: Nil), + UnresolvedWindowFunction( + "last_value", + child :: ignoreNulls :: Nil), WindowSpecDefinition(partitionSpec, orderSpec, frame)) case Min(child) => WindowExpression( UnresolvedWindowFunction("min", child :: Nil), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c4..15c864a8ab641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname collection_funcs Collection functions - * @groupname Ungrouped Support functions for DataFrames. + * @groupname Ungrouped Support functions for DataFrames * @since 1.3.0 */ @Experimental @@ -294,6 +294,33 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = Stddev(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** * Aggregate function: returns the sum of all values in the expression. * @@ -796,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b98..88ae83957a708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. 
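// Usage sketch for the standard-deviation aggregates added to functions.scala above
// (`df` is a hypothetical DataFrame with a numeric "value" column):
import org.apache.spark.sql.functions.{col, stddev, stddev_pop, stddev_samp}

df.agg(
  stddev(col("value")),       // unbiased sample standard deviation
  stddev_samp(col("value")),  // the same sample statistic under its SQL-style name
  stddev_pop(col("value")))   // population standard deviation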
+ */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -126,6 +137,9 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + registerDialect(DerbyDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -184,11 +198,15 @@ case object PostgresDialect extends JdbcDialect { override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Some(BinaryType) + Option(BinaryType) } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Some(StringType) + Option(StringType) } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Some(StringType) + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("json")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { + Option(StringType) } else None } @@ -198,6 +216,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -213,15 +236,19 @@ case object MySQLDialect extends JdbcDialect { // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. md.putLong("binarylong", 1) - Some(LongType) + Option(LongType) } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - Some(BooleanType) + Option(BooleanType) } else None } override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** @@ -240,3 +267,51 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. + */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommended by Microsoft SQL Server for datetimeoffset types in non-MS clients + Option(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case _ => None + } +} + +/** + * :: DeveloperApi :: + * Default Apache Derby dialect, mapping real on read + * and string/byte/short/boolean/decimal on write. 
+ */ +@DeveloperApi +case object DerbyDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case (t: DecimalType) if (t.precision > 31) => + Some(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } + +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7b030b7d73bd5..84eef0f8a672c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration @@ -544,11 +544,32 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } private def discoverPartitions(): PartitionSpec = { - val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled() // We use leaf dirs containing data files to discover the schema. val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference) + userDefinedPartitionColumns match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false) + + // Without auto inference, all of the values in the `row` should be null or of StringType, + // so we need to cast them into the data types that the user specified. 
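// (Editorial illustration, not part of the original patch.) With inference turned off, a leaf
// directory such as .../year=2015/month=1 yields the string partition values "2015" and "1";
// if the user-provided schema declares both columns as IntegerType, the helper below evaluates
// e.g. Cast(Literal.create("2015", StringType), IntegerType).eval() and stores the integer 2015.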
+ def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getString(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + + case _ => + // user did not provide a partitioning schema + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 2fdd798b44bb6..8d4854b698ed7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.test -import java.util - -import scala.collection.JavaConverters._ +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ /** @@ -39,22 +37,20 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] { override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): GenericArrayData = { obj match { case p: ExamplePoint => - Seq(p.x, p.y) + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) } } override def deserialize(datum: Any): ExamplePoint = { datum match { - case values: Seq[_] => - val xy = values.asInstanceOf[Seq[Double]] - assert(xy.length == 2) - new ExamplePoint(xy(0), xy(1)) - case values: util.ArrayList[_] => - val xy = values.asInstanceOf[util.ArrayList[Double]].asScala - new ExamplePoint(xy(0), xy(1)) + case values: ArrayData => + new ExamplePoint(values.getDouble(0), values.getDouble(1)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala new file mode 100644 index 0000000000000..909a8abd225b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.util + +import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.mutable.ListBuffer + +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.Logging +import org.apache.spark.sql.execution.QueryExecution + + +/** + * The interface of a query execution listener that can be used to analyze execution metrics. + * + * Note that implementations should guarantee thread-safety, as they may be + * invoked from multiple threads. + */ +@Experimental +trait QueryExecutionListener { + + /** + * A callback function that will be called when a query is executed successfully. + * Implementations should be thread-safe. + * + * @param funcName the name of the action that triggered this query. + * @param qe the QueryExecution object that carries detailed information like the logical plan, + * physical plan, etc. + * @param duration the execution time for this query in nanoseconds. + */ + @DeveloperApi + def onSuccess(funcName: String, qe: QueryExecution, duration: Long) + + /** + * A callback function that will be called when a query execution fails. + * Implementations should be thread-safe. + * + * @param funcName the name of the action that triggered this query. + * @param qe the QueryExecution object that carries detailed information like the logical plan, + * physical plan, etc. + * @param exception the exception that failed this query. + */ + @DeveloperApi + def onFailure(funcName: String, qe: QueryExecution, exception: Exception) +} + +@Experimental +class ExecutionListenerManager extends Logging { + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + private[this] val lock = new ReentrantReadWriteLock() + + /** Acquires a read lock on the listener list for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the listener list for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } + + /** + * Registers the specified QueryExecutionListener. + */ + @DeveloperApi + def register(listener: QueryExecutionListener): Unit = writeLock { + listeners += listener + } + + /** + * Unregisters the specified QueryExecutionListener. + */ + @DeveloperApi + def unregister(listener: QueryExecutionListener): Unit = writeLock { + listeners -= listener + } + + /** + * Clears out all registered QueryExecutionListeners. 
+ */ + @DeveloperApi + def clear(): Unit = writeLock { + listeners.clear() + } + + private[sql] def onSuccess( + funcName: String, + qe: QueryExecution, + duration: Long): Unit = readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } + } + + private[sql] def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception): Unit = readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } + } + + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { + for (listener <- listeners) { + try { + f(listener) + } catch { + case e: Exception => logWarning("error executing query execution listener", e) + } + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c393f6..7b50aad4ad498 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); 
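(Editorial aside, not part of the patch.) A minimal sketch of the QueryExecutionListener / ExecutionListenerManager API added in QueryExecutionListener.scala above, written as it might look in a spark-shell session; the `sqlContext.listenerManager` accessor is assumed here and is not shown in this section of the diff:

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// A thread-safe listener that just counts successful and failed query executions.
class CountingListener extends QueryExecutionListener {
  val succeeded = new AtomicInteger(0)
  val failed = new AtomicInteger(0)

  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
    succeeded.incrementAndGet()   // duration is in nanoseconds per the scaladoc above

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    failed.incrementAndGet()
}

val listener = new CountingListener
sqlContext.listenerManager.register(listener)   // assumed accessor for the manager
// ... run some queries; onSuccess/onFailure update the counters ...
sqlContext.listenerManager.unregister(listener)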
expected.add("Yin_28"); @@ -165,7 +168,7 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); + List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf5328c..a1a3fdbb486ea 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -61,7 +62,7 @@ public void tearDown() { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -90,6 +91,7 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions @@ -119,7 +121,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -140,11 +142,7 @@ public List getD() { } } - @Test - public void testCreateDataFrameFromJavaBeans() { - Bean bean = new Bean(); - JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); - DataFrame df = context.createDataFrame(rdd, Bean.class); + void validateDataFrameWithBeans(Bean bean, DataFrame df) { StructType schema = df.schema(); Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()), schema.apply("a")); @@ -161,7 +159,7 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. 
Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +178,33 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + @Test + public void testCreateDataFrameFromLocalJavaBeans() { + Bean bean = new Bean(); + List data = Arrays.asList(bean); + DataFrame df = context.createDataFrame(data, Bean.class); + validateDataFrameWithBeans(bean, df); + } + + @Test + public void testCreateDataFrameFromJavaBeans() { + Bean bean = new Bean(); + JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); + DataFrame df = context.createDataFrame(rdd, Bean.class); + validateDataFrameWithBeans(bean, df); + } + + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +217,16 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +234,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +243,14 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +258,7 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), 
RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26a..3ab4db2a035d3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca9be..4a78dca7fea66 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - public Integer call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ 
public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc39c..9e241f20987c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public void tearDown() { @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/dec-in-i32.parquet new file mode 100755 index 0000000000000..bb5d4af8dd368 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-i32.parquet differ diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/dec-in-i64.parquet new file mode 100755 index 0000000000000..e07c4a0ad9843 Binary files /dev/null and b/sql/core/src/test/resources/dec-in-i64.parquet differ diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/text-suite.txt new file mode 100644 index 0000000000000..e8fd967197fe8 --- /dev/null +++ b/sql/core/src/test/resources/text-suite.txt @@ -0,0 +1,4 @@ +This is a test file for the text data source +1+1 +数据砖头 +"doh" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index af7590c3d3c17..fd566c8276bc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql +import 
org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.execution.PhysicalRDD + import scala.concurrent.duration._ import scala.language.postfixOps @@ -34,7 +37,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -44,7 +47,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("withColumn doesn't invalidate cached dataframe") { @@ -69,41 +72,41 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + sqlContext.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != ctx.cacheManager.lookupCachedData(testData)) + assert(None != sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + sqlContext.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + sqlContext.uncacheTable("tempTable2") // Should this be cached? 
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -111,103 +114,103 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("too big for memory") { val data = "*" * 1000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(sqlContext.table("bigData").count() === 200000L) + sqlContext.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + sqlContext.table("testData").cache() + assertCached(sqlContext.table("testData")) + sqlContext.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + sqlContext.table("testData").cache() + sqlContext.table("testData").count() + sqlContext.table("testData").unpersist(blocking = true) + assertCached(sqlContext.table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(sqlContext.table("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + sqlContext.uncacheTable("testData") + assert(!sqlContext.isCached("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + assertCached(sqlContext.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + 
assertCached(sqlContext.table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + sqlContext.uncacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + sqlContext.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + sqlContext.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -215,7 +218,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -224,14 +227,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -239,14 +242,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -254,7 +257,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -266,7 +269,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { isMaterialized(rddId), "Lazily 
cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -274,7 +277,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -283,46 +286,46 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + sqlContext.table("t1") + sqlContext.dropTempTable("t1") + intercept[NoSuchTableException](sqlContext.table("t1")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + sqlContext.cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(sqlContext.isCached("t1")) + assert(sqlContext.isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + sqlContext.dropTempTable("t1") + intercept[NoSuchTableException](sqlContext.table("t1")) + assert(!sqlContext.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + sqlContext.clearCache() + assert(sqlContext.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("Clear CACHE") - assert(ctx.cacheManager.isEmpty) + assert(sqlContext.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() @@ -331,9 +334,23 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } + + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { + sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) + .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") + sqlContext.cacheTable("abc") + + val sparkPlan = sql( + """select a.key, b.key, c.key from + |abc a join abc b on a.key=b.key + |join abc c on 
a.key=c.key""".stripMargin).queryExecution.sparkPlan + + assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3) + assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 37738ec5b3c1d..fa559c9c64005 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - ctx.createDataFrame(ctx.sparkContext.parallelize( + sqlContext.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -286,7 +286,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -307,7 +307,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -350,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -411,7 +411,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -556,7 +556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -567,7 +567,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("sparkPartitionId") { // Make sure we have 2 partitions, each with 2 records. 
- val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -578,7 +578,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("InputFileName") { withTempPath { dir => - val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) .head.getString(0) @@ -588,12 +588,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } - test("lift alias out of cast") { - compareExpressions( - col("1234").as("name").cast("int").expr, - col("1234").cast("int").as("name").expr) - } - test("columns can be compared") { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 72cf7aab0b977..f5ef9ffd7f4f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -66,12 +66,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { @@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 3c359dd840ab7..09f7b507670c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -28,19 +28,19 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { test("UDF on struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(struct($"a").as("s")).select(f($"s.a")).collect() } test("UDF on named_struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 
1))).toDF("a", "b") df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() } test("UDF on array") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index e5d7d63441a6b..094efbaeadcd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -24,7 +24,7 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d85..56ad71ea4f487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 329ffb66083b1..e34875471f093 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,24 +141,26 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double)]( - (null, null, null, null)).toDF("a", "b", "c", "d") + val df = Seq[(String, String, 
java.lang.Long, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false )), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) // Test Java version checkAnswer( df.na.fill(Map( "a" -> "test", "c" -> 1, - "d" -> 2.2 + "d" -> 2.2, + "e" -> false ).asJava), - Row("test", null, 1, 2.2)) + Row("test", null, 1, 2.2, false)) } test("replace") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 28bdd6f83b687..6524abcf5e97f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -29,7 +29,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample with replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -38,7 +38,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample without replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -47,7 +47,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -164,7 +164,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("Frequent Items 2") { - val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. 
This test should also fail // if anything like SPARK-9614 is observed once again @@ -182,7 +182,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = ctx.range(0, 100).select((col("id") % 3).as("key")) + val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 284fff184085a..f4c7aa34e560c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -22,6 +22,8 @@ import java.io.File import scala.language.postfixOps import scala.util.Random +import org.scalatest.Matchers._ + import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -345,7 +347,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("replace column using withColumn") { - val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -386,13 +388,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "value")) } - test("drop unknown column with same name (no-op) with column reference") { + test("drop unknown column with same name with column reference") { val col = Column("key") val df = testData.drop(col) checkAnswer( df, - testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key", "value")) + testData.collect().map(x => Row(x.getString(1))).toSeq) + assert(df.schema.map(_.name) === Seq("value")) } test("drop column after join with duplicate columns using column reference") { @@ -434,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) @@ -506,7 +508,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -596,7 +598,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() @@ -619,14 +621,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: 
Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -646,7 +648,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7324 dropDuplicates") { - val testData = sqlContext.sparkContext.parallelize( + val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -869,7 +871,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -887,4 +889,95 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .select(struct($"b")) .collect() } + + test("SPARK-10185: Read multiple Hadoop Filesystem paths and paths with a comma in it") { + withTempDir { dir => + val df1 = Seq((1, 22)).toDF("a", "b") + val dir1 = new File(dir, "dir,1").getCanonicalPath + df1.write.format("json").save(dir1) + + val df2 = Seq((2, 23)).toDF("a", "b") + val dir2 = new File(dir, "dir2").getCanonicalPath + df2.write.format("json").save(dir2) + + checkAnswer(sqlContext.read.format("json").load(Array(dir1, dir2)), + Row(1, 22) :: Row(2, 23) :: Nil) + + checkAnswer(sqlContext.read.format("json").load(dir1), + Row(1, 22) :: Nil) + } + } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } + + test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { + val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + (1 to 10).map(i => s"""{"id": $i}"""))) + + val df = input.select($"id", rand(0).as('r)) + df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row => + assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) + } + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } + + test("SPARK-10740: handle nondeterministic expressions correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. 
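For reference, rand(seed) is deterministic per partition: each task re-seeds its generator with (seed + partitionIndex), which is what lets the expected() helper below recompute the rows outside the query. A minimal standalone sketch of that property (not part of the patch), assuming an active sqlContext:

    import org.apache.spark.sql.functions.rand

    // Two runs of the same seeded query over the same partitioning produce identical
    // values, because every partition re-seeds its generator with (seed + partitionIndex).
    val nums = sqlContext.range(0, 8, 1, 2)          // 8 rows spread over 2 partitions
    val first = nums.select(rand(7)).collect().map(_.getDouble(0)).toSeq
    val second = nums.select(rand(7)).collect().map(_.getDouble(0)).toSeq
    assert(first == second)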
+ def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } + + test("SPARK-10743: keep the name of expression if possible when do cast") { + val df = (1 to 10).map(Tuple1.apply).toDF("i").as("src") + assert(df.select($"src.i".cast(StringType)).columns.head === "i") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 77907e91363ec..7ae12a7895f7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -32,7 +32,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } } @@ -40,7 +40,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test struct type") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val struct = Row(1, 2L, 3.0F, 3.0) - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + val data = sparkContext.parallelize(Seq(Row(1, struct))) val schema = new StructType() .add("a", IntegerType) @@ -60,7 +60,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val innerStruct = Row(1, "abcd") val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) val schema = new StructType() .add("a", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala new file mode 100644 index 0000000000000..32443557fb8e0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +case class IntClass(value: Int) + +class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(1, 2, 3, 4, 5, 6) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(1, 2, 3).toDS().as[IntClass] + checkAnswer( + ds, + IntClass(1), IntClass(2), IntClass(3)) + + assert(ds.collect().head == IntClass(1)) + } + + test("map") { + val ds = Seq(1, 2, 3).toDS() + checkAnswer( + ds.map(_ + 1), + 2, 3, 4) + } + + test("filter") { + val ds = Seq(1, 2, 3, 4).toDS() + checkAnswer( + ds.filter(_ % 2 == 0), + 2, 4) + } + + test("foreach") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(acc +=) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(acc +=)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.reduce(_ + _) == 6) + } + + test("fold") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.fold(0)(_ + _) == 6) + } + + test("groupBy function, keys") { + val ds = Seq(1, 2, 3, 4, 5).toDS() + val grouped = ds.groupBy(_ % 2) + checkAnswer( + grouped.keys, + 0, 1) + } + + test("groupBy function, mapGroups") { + val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() + val grouped = ds.groupBy(_ % 2) + val agged = grouped.mapGroups { case (g, iter) => + val name = if (g == 0) "even" else "odd" + Iterator((name, iter.size)) + } + + checkAnswer( + agged, + ("even", 5), ("odd", 6)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala new file mode 100644 index 0000000000000..08496249c60cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +case class ClassData(a: String, b: Int) + +class DatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + checkAnswer( + ds, + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + assert(ds.collect().head == ClassData("a", 1)) + } + + test("as case class - reordered fields by name") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) + } + + test("map") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.map(v => (v._1, v._2 + 1)), + ("a", 2), ("b", 3), ("c", 4)) + } + + test("select") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select(expr("_2 + 1").as[Int]), + 2, 3, 4) + } + + test("select 3") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("_2").as[Int], + expr("_2 + 1").as[Int]), + ("a", 1, 2), ("b", 2, 3), ("c", 3, 4)) + } + + test("filter") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.filter(_._1 == "b"), + ("b", 2)) + } + + test("foreach") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(v => acc += v._2) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(v => acc += v._2)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("fold") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("groupBy function, keys") { + val ds = Seq(("a", 1), ("b", 1)).toDS() + val grouped = ds.groupBy(v => (1, v._2)) + checkAnswer( + grouped.keys, + (1, 1)) + } + + test("groupBy function, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g._1, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 8d2f45d70308b..78a98798eff64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -52,7 +52,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { try { sqlContext.experimental.extraStrategies = TestStrategy :: Nil - val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( df.select("a"), Row("so fast")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 
f5c5046a8ed88..b1fb06815868c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.SharedSQLContext @@ -31,7 +32,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -59,7 +60,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -83,7 +84,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), + classOf[SortMergeOuterJoin]), ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", @@ -118,7 +119,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { @@ -138,7 +139,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( @@ -167,7 +168,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -359,8 +360,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { upperCaseData.where('N <= 4).registerTempTable("left") upperCaseData.where('N >= 3).registerTempTable("right") - val left = UnresolvedRelation(Seq("left"), None) - val right = UnresolvedRelation(Seq("right"), None) + val left = UnresolvedRelation(TableIdentifier("left"), None) + val right = UnresolvedRelation(TableIdentifier("right"), None) checkAnswer( left.join(right, $"left.N" === $"right.N", "full"), @@ -442,7 +443,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 045fea82e4c89..e3531d0d6d799 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -29,4 +29,42 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("alice", "5")) } + + val tuples: Seq[(String, String)] = + ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: + ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: + ("3", """{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: + ("4", null) :: + ("5", """{"f1": "", "f5": null}""") :: + ("6", "[invalid JSON string]") :: + Nil + + test("json_tuple select") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: + Row("2", Row("value12", "2", "value3", "4.01", null)) :: + Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: + Row("4", Row(null, null, null, null, null)) :: + Row("5", Row("", null, null, null, null)) :: + Row("6", Row(null, null, null, null, null)) :: + Nil + + checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected) + } + + test("json_tuple filter and group") { + val df: DataFrame = tuples.toDF("key", "jstring") + val expr = df + .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt") + .where($"jt.c0".isNotNull) + .groupBy($"jt.c1") + .count() + + val expected = Row(null, 1) :: + Row("2", 2) :: + Row("value2", 1) :: + Nil + + checkAnswer(expr, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index babf8835d2545..5688f46e5e3d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} +import org.apache.spark.sql.catalyst.TableIdentifier class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ @@ -32,33 +33,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - 
assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -66,7 +67,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), sql("SHOW TABLes")).foreach { + Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -77,9 +78,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + sqlContext.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 30289c3c1d097..58f982c2bc932 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { private lazy val nullDoubles = Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() - private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private def testOneToOneMathFunction[ + @specialized(Int, Long, Float, Double) T, + @specialized(Int, Long, Float, Double) U]( c: Column => Column, - f: T => T): Unit = { + f: T => U): Unit = { checkAnswer( doubleData.select(c('a)), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) @@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("ceil and ceiling") { - testOneToOneMathFunction(ceil, math.ceil) + testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong) checkAnswer( sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), - Row(0.0, 1.0, 2.0)) + Row(0L, 1L, 2L)) } test("conv") { @@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("floor") { - testOneToOneMathFunction(floor, math.floor) + testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong) } test("factorial") { @@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { } test("signum / sign") { - testOneToOneMathFunction[Double](signum, math.signum) + testOneToOneMathFunction[Double, Double](signum, math.signum) checkAnswer( sql("SELECT sign(10), signum(-11)"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala new file mode 100644 index 0000000000000..0e8fcb6a858b1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -0,0 +1,99 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark._ +import org.scalatest.BeforeAndAfterAll + +class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSQLContext: Option[SQLContext] = _ + private var originalInstantiatedSQLContext: Option[SQLContext] = _ + private var sparkConf: SparkConf = _ + + override protected def beforeAll(): Unit = { + originalActiveSQLContext = SQLContext.getActiveContextOption() + originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + + SQLContext.clearActive() + originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx)) + sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set("spark.ui.enabled", "false") + .set("spark.driver.allowMultipleContexts", "true") + } + + override protected def afterAll(): Unit = { + // Set these states back. + originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + } + + def testNewSession(rootSQLContext: SQLContext): Unit = { + // Make sure we can successfully create new Session. + rootSQLContext.newSession() + + // Reset the state. It is always safe to clear the active context. + SQLContext.clearActive() + } + + def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = { + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString) + val sparkContext = new SparkContext(conf) + + try { + if (allowsMultipleContexts) { + new SQLContext(sparkContext) + SQLContext.clearActive() + } else { + // If allowsMultipleContexts is false, make sure we can get the error. 
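The same guard can be seen outside the test harness with a compact sketch (not part of the patch). The key string "spark.sql.allowMultipleContexts" is an assumption here; the patch itself goes through SQLConf.ALLOW_MULTIPLE_CONTEXTS.key.

    import org.apache.spark.{SparkConf, SparkContext, SparkException}
    import org.apache.spark.sql.SQLContext

    val conf = new SparkConf(false)
      .setMaster("local[*]")
      .setAppName("single-sqlcontext-demo")
      .set("spark.sql.allowMultipleContexts", "false")   // assumed key string
    val sc = new SparkContext(conf)
    try {
      val root = new SQLContext(sc)        // the first root context is fine
      try {
        new SQLContext(sc)                 // a second root context should be rejected
      } catch {
        case e: SparkException =>
          // Message mentions "Only one SQLContext/HiveContext may be running".
          println(e.getMessage)
      }
    } finally {
      sc.stop()
    }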
+ val message = intercept[SparkException] { + new SQLContext(sparkContext) + }.getMessage + assert(message.contains("Only one SQLContext/HiveContext may be running")) + } + } finally { + sparkContext.stop() + } + } + + test("test the flag to disallow creating multiple root SQLContext") { + Seq(false, true).foreach { allowMultipleSQLContexts => + val conf = + sparkConf + .clone + .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString) + val sc = new SparkContext(conf) + try { + val rootSQLContext = new SQLContext(sc) + testNewSession(rootSQLContext) + testNewSession(rootSQLContext) + testCreatingNewSQLContext(allowMultipleSQLContexts) + + SQLContext.clearInstantiatedContext(rootSQLContext) + } finally { + sc.stop() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3649c2a97b5ef..aba567512fe32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder} -class QueryTest extends PlanTest { +abstract class QueryTest extends PlanTest { + + protected def sqlContext: SQLContext // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -51,23 +55,44 @@ class QueryTest extends PlanTest { } } + protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = { + checkAnswer( + ds.toDF(), + sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
*/ - protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(df, expectedAnswer) match { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + val currentValue = sqlContext.conf.dataFrameEagerAnalysis + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val partiallyAnalzyedPlan = df.queryExecution.analyzed + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) + fail( + s""" + |Failed to analyze query: $ae + |$partiallyAnalzyedPlan + | + |${stackTraceToString(ae)} + |""".stripMargin) + } + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => } } - protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } - protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } @@ -98,19 +123,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. 
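The conversions above matter because the raw Java and array types do not compare the way the tests need, and nested struct columns come back as Rows that may themselves contain such values, hence the recursion in prepareRow. A small REPL-style sketch of the two problem cases (values are hypothetical, not from the patch):

    // java.math.BigDecimal equality is scale-sensitive; Scala's BigDecimal compares by value.
    val j1 = new java.math.BigDecimal("1.0")
    val j2 = new java.math.BigDecimal("1.00")
    j1 == j2                           // false: same value, different scale
    BigDecimal(j1) == BigDecimal(j2)   // true: compareTo-style equality

    // Arrays compare by reference; converting to Seq gives element-wise equality.
    val b1 = Array[Byte](1, 2, 3)
    val b2 = Array[Byte](1, 2, 3)
    b1 == b2                           // false
    b1.toSeq == b2.toSeq               // true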
- val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 77ccd6f775e50..3ba14d7602a62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -57,7 +57,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 7699adadd9cc8..3d2bd236ceead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { @@ -27,58 +27,67 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(ctx.sparkContext) + val newContext = new SQLContext(sparkContext) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) - - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + // Set a conf first. + sqlContext.setConf(testKey, testVal) + // Clear the conf. + sqlContext.conf.clear() + // After clear, only overrideConfs used by unit test should be in the SQLConf. + assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
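The contrast the comment draws can be seen directly. A minimal sketch (not part of the patch), assuming an active sqlContext and its underlying sparkContext; "spark.sql.shuffle.partitions" is just a convenient real key and "42" is an arbitrary value:

    // SQLConf is mutable on a live SQLContext: the new value takes effect immediately.
    sqlContext.setConf("spark.sql.shuffle.partitions", "42")
    require(sqlContext.getConf("spark.sql.shuffle.partitions") == "42")

    // SparkConf, by contrast, is effectively frozen once the SparkContext is running:
    // getConf returns a copy, so mutating it does not change the running context.
    val copy = sparkContext.getConf
    copy.set("spark.app.name", "renamed-after-start")
    require(sparkContext.appName != "renamed-after-start")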
- assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + sqlContext.conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() + sqlContext.conf.clear() sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") + assert(sqlContext.getConf("some.property", "0") === "20") sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + assert(sqlContext.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + assert(sqlContext.getConf(key, "0") === vs) sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + assert(sqlContext.getConf(key, "0") === "") - ctx.conf.clear() + sqlContext.conf.clear() } test("deprecated property") { - ctx.conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.conf.numShufflePartitions === 10) + sqlContext.conf.clear() + val original = sqlContext.conf.numShufflePartitions + try{ + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(sqlContext.conf.numShufflePartitions === 10) + } finally { + sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") + } } test("invalid conf value") { - ctx.conf.clear() + sqlContext.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 007be12950774..1994dacfc4dfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,33 +17,52 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.{SharedSparkContext, SparkFunSuite} -class SQLContextSuite extends SparkFunSuite with SharedSQLContext { - - override def afterAll(): Unit = { - try { - SQLContext.setLastInstantiatedContext(ctx) - } finally { - super.afterAll() - } - } +class SQLContextSuite extends SparkFunSuite with SharedSparkContext{ test("getOrCreate instantiates SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(sc) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } - test("getOrCreate gets last explicitly instantiated SQLContext") { - SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, - "SQLContext.getOrCreate after explicitly created SQLContext returned null") - 
assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + test("getOrCreate return the original SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val newSession = sqlContext.newSession() + assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + SQLContext.setActive(newSession) + assert(SQLContext.getOrCreate(sc).eq(newSession), + "SQLContext.getOrCreate after explicitly setActive() did not return the active context") + } + + test("Sessions of SQLContext") { + val sqlContext = SQLContext.getOrCreate(sc) + val session1 = sqlContext.newSession() + val session2 = sqlContext.newSession() + + // all have the default configurations + val key = SQLConf.SHUFFLE_PARTITIONS.key + assert(session1.getConf(key) === session2.getConf(key)) + session1.setConf(key, "1") + session2.setConf(key, "2") + assert(session1.getConf(key) === "1") + assert(session2.getConf(key) === "2") + + // temporary table should not be shared + val df = session1.range(10) + df.registerTempTable("test1") + assert(session1.tableNames().contains("test1")) + assert(!session2.tableNames().contains("test1")) + + // UDF should not be shared + def myadd(a: Int, b: Int): Int = a + b + session1.udf.register[Int, Int, Int]("myadd", myadd) + session1.sql("select myadd(1, 2)").explain() + intercept[AnalysisException] { + session2.sql("select myadd(1, 2)").explain() + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e172b2c264cb..298c32290697a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -25,9 +25,10 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.joins.{SortMergeJoin, CartesianProduct} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -147,14 +148,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -196,7 +197,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( + sqlContext.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -215,7 +216,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { 
test("SPARK-6201 IN type conversion") { sqlContext.read.json( - sqlContext.sparkContext.parallelize( + sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -328,6 +329,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -348,8 +356,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -515,8 +523,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -574,28 +582,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -722,6 +714,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), @@ -832,6 +851,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(null, null, 6, "F") :: Nil) } + test("SPARK-11111 null-safe join should not use cartesian product") { + val df = sql("select count(*) from testData a join testData b on (a.key <=> b.key)") + val cp = df.queryExecution.executedPlan.collect { + case cp: CartesianProduct => cp + } + assert(cp.isEmpty, "should not use CartesianProduct for null-safe join") + val smj = df.queryExecution.executedPlan.collect { 
+ case smj: SortMergeJoin => smj + } + assert(smj.size > 0, "should use SortMergeJoin") + checkAnswer(df, Row(100) :: Nil) + } + test("SPARK-3349 partitioning after limit") { sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC") .limit(2) @@ -991,21 +1023,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val nonexistentKey = "nonexistent" // "set" itself returns all config variables currently specified in SQLConf. - assert(sql("SET").collect().size == 0) + assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size) + sql("SET").collect().foreach { row => + val key = row.getString(0) + val value = row.getString(1) + assert( + TestSQLContext.overrideConfs.contains(key), + s"$key should exist in SQLConf.") + assert( + TestSQLContext.overrideConfs(key) === value, + s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.") + } + val overrideConfs = sql("SET").collect() // "set key=val" sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(testKey, testVal) + overrideConfs ++ Seq(Row(testKey, testVal)) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), - Seq( - Row(testKey, testVal), - Row(testKey + testKey, testVal + testVal)) + overrideConfs ++ Seq(Row(testKey, testVal), Row(testKey + testKey, testVal + testVal)) ) // "set key" @@ -1342,7 +1383,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") @@ -1385,13 +1426,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } @@ -1412,10 +1453,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1424,7 +1465,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = 
sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1432,14 +1473,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-4699 case sensitivity SQL query") { sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1452,14 +1493,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1490,6 +1531,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ @@ -1509,6 +1560,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { @@ -1533,7 +1604,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1600,8 +1671,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen updates peak execution memory") { withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + 
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { testCodeGen( "SELECT key, count(value) FROM testData GROUP BY key", (1 to 100).map(i => Row(i, 1))) @@ -1659,17 +1729,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } test("SPARK-9511: error with table starting with number") { withTempTable("1one") { - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") .registerTempTable("1one") checkAnswer(sql("select count(num) from 1one"), Row(10)) @@ -1680,7 +1747,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withTempPath { dir => val path = dir.getCanonicalPath val df = - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df .write .format("parquet") @@ -1712,9 +1779,85 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10130 type coercion for IF should have children resolved first") { - val df = Seq((1, 1), (-1, 1)).toDF("key", "value") - df.registerTempTable("src") - checkAnswer( - sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } + } + + test("run sql directly on files") { + val df = sqlContext.range(100) + withTempPath(f => { + df.write.json(f.getCanonicalPath) + checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select id from `org.apache.spark.sql.json`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select a.id from json.`${f.getCanonicalPath}` as a"), + df) + }) + + val e1 = intercept[AnalysisException] { + sql("select * from in_valid_table") + } + assert(e1.message.contains("Table not found")) + + val e2 = intercept[AnalysisException] { + sql("select * from no_db.no_table") + } + assert(e2.message.contains("Table not found")) + + val e3 = intercept[AnalysisException] { + sql("select * from json.invalid_file") + } + assert(e3.message.contains("No input paths specified")) + } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. 
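To see which join operator the planner actually picks for a query of this shape, the physical plan can be inspected the same way the SPARK-11111 test above does. A minimal sketch (not part of the patch), assuming an active sqlContext with testImplicits imported; the broadcast-threshold key is the one named in the patch via SQLConf:

    import org.apache.spark.sql.execution.joins.SortMergeJoin

    // Force equi-joins away from broadcast so a sort-merge join is selected.
    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "1")

    val left = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j")
    val right = left.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1")
    val joined = left.join(right, left("i") === right("i1"))

    val smjs = joined.queryExecution.executedPlan.collect { case s: SortMergeJoin => s }
    // Expected to be non-empty once broadcast joins are ruled out.
    println(smjs.size)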
+ val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } + + test("SPARK-11032: resolve having correctly") { + withTempTable("src") { + Seq(1 -> "a").toDF("i", "j").registerTempTable("src") + checkAnswer( + sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"), + Row(1)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 45d0ee4a8e749..ddab918629645 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sqlContext.sparkContext) + val _sqlContext = new SQLContext(sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index b91438baea06f..e12e6bea30260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -268,9 +268,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row(3, 4)) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) + df.selectExpr("length(c)") // int type of the argument is unacceptable } } @@ -284,63 +282,46 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select(format_number($"f", 4)), + val df = sqlContext.range(1) + + checkAnswer( + df.select(format_number(lit(5L), 4)), Row("5.0000")) checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer Row("1.0000")) checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer Row("2.0000")) checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double Row("3.1322")) checkAnswer( - df.selectExpr("format_number(e, e)"), // not convert anything + df.select(format_number(lit(4), 
4)), // not convert anything Row("4.0000")) checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything + df.select(format_number(lit(5L), 4)), // not convert anything Row("5.0000")) checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything + df.select(format_number(lit(6.48173), 4)), // not convert anything Row("6.4817")) checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything + df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything Row("7.1284")) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) + df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable } intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) + df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable } // for testing the mutable state of the expression in code gen. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index eb275af101e2f..e0435a0dba6ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - ctx.dropTempTable("tmp_table") + sqlContext.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => - val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - ctx.dropTempTable("test_table") + sqlContext.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -87,24 +87,24 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) + sqlContext.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) + 
sqlContext.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = ctx.sparkContext.parallelize( + val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("integerData") @@ -114,7 +114,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -133,7 +133,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -150,10 +150,10 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) - ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) - ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) - ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) + sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) + sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -172,7 +172,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -181,13 +181,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - ctx.udf.register("intExpected", (x: Int) => x) + sqlContext.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 2476b10e3cf9e..00f1526576cc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql import java.io.ByteArrayOutputStream -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryAllocator @@ -29,6 +31,32 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeRowSuite extends SparkFunSuite { + test("UnsafeRow Java serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new JavaSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + + test("UnsafeRow Kryo serialization") { + // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data + val data = new Array[Byte](1024) + val row = new UnsafeRow + row.pointTo(data, 1, 16) + row.setLong(0, 19285) + + val ser = new KryoSerializer(new SparkConf).newInstance() + val row1 = ser.deserialize[UnsafeRow](ser.serialize(row)) + assert(row1.getLong(0) == 19285) + assert(row1.getBaseObject().asInstanceOf[Array[Byte]].length == 16) + } + test("bitset width calculation") { assert(UnsafeRow.calculateBitSetWidthInBytes(0) === 0) assert(UnsafeRow.calculateBitSetWidthInBytes(1) === 8) @@ -131,4 +159,11 @@ class UnsafeRowSuite extends SparkFunSuite { assert(emptyRow.getInt(0) === unsafeRow.getInt(0)) assert(emptyRow.getUTF8String(1) === unsafeRow.getUTF8String(1)) } + + test("calling hashCode on unsafe array returned by getArray(ordinal)") { + val row = InternalRow.apply(new GenericArrayData(Array(1L))) + val unsafeRow = UnsafeProjection.create(Array[DataType](ArrayType(LongType))).apply(row) + // Makes sure hashCode on unsafe array won't crash + unsafeRow.getArray(0).hashCode() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b6d279ae47268..a229e5814df89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} + import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import 
org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -67,7 +71,7 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { private[spark] override def asNullable: MyDenseVectorUDT = this } -class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { +class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { import testImplicits._ private lazy val pointsRDD = Seq( @@ -90,24 +94,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } - - test("UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.write.parquet(path) + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } - test("Repartition UDTs with Parquet") { - val tempDir = Utils.createTempDir() - tempDir.delete() - pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) + testStandardAndLegacyModes("Repartition UDTs with Parquet") { + withTempPath { dir => + val path = dir.getCanonicalPath + pointsRDD.repartition(1).write.parquet(path) + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(1.0, new MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + } } // Tests to make sure that all operators correctly convert types on the way out. 
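For readers following the MyDenseVectorUDT hunks above: the 1.x-era UserDefinedType contract these tests exercise looks roughly like the sketch below (MyPoint and MyPointUDT are invented names used only for illustration, not types from the patch):

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

case class MyPoint(x: Double, y: Double)

class MyPointUDT extends UserDefinedType[MyPoint] {
  // Storage representation used by columnar caching, Parquet, JSON, etc.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // User object -> Catalyst value.
  override def serialize(obj: Any): Any = obj match {
    case MyPoint(x, y) => new GenericArrayData(Array[Any](x, y))
  }

  // Catalyst value -> user object.
  override def deserialize(datum: Any): MyPoint = datum match {
    case data: ArrayData => MyPoint(data.getDouble(0), data.getDouble(1))
  }

  override def userClass: Class[MyPoint] = classOf[MyPoint]
}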
@@ -148,8 +163,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { StructField("vec", new MyDenseVectorUDT, false) )) - val stringRDD = ctx.sparkContext.parallelize(data) - val jsonRDD = ctx.read.schema(schema).json(stringRDD) + val stringRDD = sparkContext.parallelize(data) + val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -157,4 +172,20 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { Nil ) } + + test("SPARK-10472 UserDefinedType.typeName") { + assert(IntegerType.typeName === "integer") + assert(new MyDenseVectorUDT().typeName === "mydensevector") + assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") + } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index d0430d2a60e75..89a664001bdd2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ @@ -27,10 +26,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[DateColumnStats], DATE, createRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - createRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, createRow(Double.MaxValue, Double.MinValue, 0)) @@ -79,11 +75,11 @@ class ColumnStatsSuite extends SparkFunSuite { def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( initialStatistics: GenericInternalRow): Unit = { - val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName - val columnType = FIXED_DECIMAL(15, 10) + val columnStatsName = classOf[DecimalColumnStats].getSimpleName + val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { - val columnStats = new FixedDecimalColumnStats(15, 10) + val columnStats = new DecimalColumnStats(15, 10) columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } @@ -92,7 +88,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.columnar.ColumnarTestUtils._ - val columnStats = new FixedDecimalColumnStats(15, 10) + val columnStats = new DecimalColumnStats(15, 10) val rows = 
Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 8f024690efd0d..63bc39bfa0307 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -17,30 +17,27 @@ package org.apache.spark.sql.columnar -import java.nio.ByteBuffer +import java.nio.{ByteOrder, ByteBuffer} -import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Kryo, Serializer} - -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Logging, SparkFunSuite} class ColumnTypeSuite extends SparkFunSuite with Logging { private val DEFAULT_BUFFER_SIZE = 512 - private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) + private val MAP_TYPE = MAP(MapType(IntegerType, StringType)) + private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType)) + private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil)) test("defaultSize") { val checks = Map( - BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, - LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, - STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, - MAP_GENERIC -> 16) + NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, + STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -50,203 +47,80 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } test("actualSize") { - def checkActualSize[JvmType]( - columnType: ColumnType[JvmType], - value: JvmType, + def checkActualSize( + columnType: ColumnType[_], + value: Any, expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { val row = new GenericMutableRow(1) - columnType.setField(row, 0, value) - columnType.actualSize(row, 0) + row.update(0, CatalystTypeConverters.convertToCatalyst(value)) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + columnType.actualSize(proj(row), 0) } } + checkActualSize(NULL, null, 0) checkActualSize(BOOLEAN, true, 1) checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) checkActualSize(INT, Int.MaxValue, 4) - checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 
4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - - val generic = Map(1 -> "a") - checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) - } - - testNativeColumnType(BOOLEAN)( - (buffer: ByteBuffer, v: Boolean) => { - buffer.put((if (v) 1 else 0).toByte) - }, - (buffer: ByteBuffer) => { - buffer.get() == 1 - }) - - testNativeColumnType(BYTE)(_.put(_), _.get) - - testNativeColumnType(SHORT)(_.putShort(_), _.getShort) - - testNativeColumnType(INT)(_.putInt(_), _.getInt) - - testNativeColumnType(DATE)(_.putInt(_), _.getInt) - - testNativeColumnType(LONG)(_.putLong(_), _.getLong) - - testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - - testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - - testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - - testNativeColumnType(FIXED_DECIMAL(15, 10))( - (buffer: ByteBuffer, decimal: Decimal) => { - buffer.putLong(decimal.toUnscaledLong) - }, - (buffer: ByteBuffer) => { - Decimal(buffer.getLong(), 15, 10) - }) - - - testNativeColumnType(STRING)( - (buffer: ByteBuffer, string: UTF8String) => { - val bytes = string.getBytes - buffer.putInt(bytes.length) - buffer.put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes) - UTF8String.fromBytes(bytes) - }) - - testColumnType[Array[Byte]]( - BINARY, - (buffer: ByteBuffer, bytes: Array[Byte]) => { - buffer.putInt(bytes.length).put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - bytes - }) - - test("GENERIC") { - val buffer = ByteBuffer.allocate(512) - val obj = Map(1 -> "spark", 2 -> "sql") - val serializedObj = SparkSqlSerializer.serialize(obj) - - MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) - buffer.rewind() - - val length = buffer.getInt() - assert(length === serializedObj.length) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - SparkSqlSerializer.deserialize(bytes) - } - - buffer.rewind() - buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - buffer.rewind() - SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) - } - } - - test("CUSTOM") { - val conf = new SparkConf() - conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") - val serializer = new SparkSqlSerializer(conf).newInstance() - - val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue, Long.MaxValue) - val serializedObj = serializer.serialize(obj).array() - - MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) - buffer.rewind() - - val length = buffer.getInt - assert(length === serializedObj.length) - assert(13 == length) // id (1) + int (4) + long (8) - - val genericSerializedObj = SparkSqlSerializer.serialize(obj) - assert(length != genericSerializedObj.length) - assert(length < genericSerializedObj.length) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - serializer.deserialize(ByteBuffer.wrap(bytes)) - } - - buffer.rewind() - buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) 
- } + checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) + checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) + checkActualSize(ARRAY_TYPE, Array[Any](1), 16) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 29) + checkActualSize(STRUCT_TYPE, Row("hello"), 28) } - def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T]) - (putter: (ByteBuffer, T#InternalType) => Unit, - getter: (ByteBuffer) => T#InternalType): Unit = { - - testColumnType[T#InternalType](columnType, putter, getter) + testNativeColumnType(BOOLEAN) + testNativeColumnType(BYTE) + testNativeColumnType(SHORT) + testNativeColumnType(INT) + testNativeColumnType(LONG) + testNativeColumnType(FLOAT) + testNativeColumnType(DOUBLE) + testNativeColumnType(COMPACT_DECIMAL(15, 10)) + testNativeColumnType(STRING) + + testColumnType(NULL) + testColumnType(BINARY) + testColumnType(LARGE_DECIMAL(20, 10)) + testColumnType(STRUCT_TYPE) + testColumnType(ARRAY_TYPE) + testColumnType(MAP_TYPE) + + def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = { + testColumnType[T#InternalType](columnType) } - def testColumnType[JvmType]( - columnType: ColumnType[JvmType], - putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType): Unit = { + def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) - val seq = (0 until 4).map(_ => makeRandomValue(columnType)) + val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) + val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy()) - test(s"$columnType.extract") { + test(s"$columnType append/extract") { buffer.rewind() - seq.foreach(putter(buffer, _)) + seq.foreach(columnType.append(_, 0, buffer)) buffer.rewind() - seq.foreach { expected => - logInfo("buffer = " + buffer + ", expected = " + expected) - val extracted = columnType.extract(buffer) - assert( - expected === extracted, - "Extracted value didn't equal to the original one. " + - hexDump(expected) + " != " + hexDump(extracted) + - ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) - } - } - - test(s"$columnType.append") { - buffer.rewind() - seq.foreach(columnType.append(_, buffer)) - - buffer.rewind() - seq.foreach { expected => - assert( - expected === getter(buffer), - "Extracted value didn't equal to the original one") + seq.foreach { row => + logInfo("buffer = " + buffer + ", expected = " + row) + val expected = converter(row.get(0, columnType.dataType)) + val extracted = converter(columnType.extract(buffer)) + assert(expected === extracted, + s"Extracted value didn't equal to the original one. 
$expected != $extracted, buffer =" + + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) } } } - private def hexDump(value: Any): String = { - value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") - } - private def dumpBuffer(buff: ByteBuffer): Any = { val sb = new StringBuilder() while (buff.hasRemaining) { @@ -259,33 +133,13 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("column type for decimal types with different precision") { (1 to 18).foreach { i => - assertResult(FIXED_DECIMAL(i, 0)) { + assertResult(COMPACT_DECIMAL(i, 0)) { ColumnType(DecimalType(i, 0)) } } - assertResult(GENERIC(DecimalType(19, 0))) { + assertResult(LARGE_DECIMAL(19, 0)) { ColumnType(DecimalType(19, 0)) } } } - -private[columnar] final case class CustomClass(a: Int, b: Long) - -private[columnar] object CustomerSerializer extends Serializer[CustomClass] { - override def write(kryo: Kryo, output: Output, t: CustomClass) { - output.writeInt(t.a) - output.writeLong(t.b) - } - override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { - val a = input.readInt() - val b = input.readLong() - CustomClass(a, b) - } -} - -private[columnar] final class Registrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[CustomClass], CustomerSerializer) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 79bb7d072feb2..a5882f7870e37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.columnar import scala.collection.immutable.HashSet import scala.util.Random + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{DataType, Decimal, AtomicType} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { @@ -39,21 +41,25 @@ object ColumnarTestUtils { } (columnType match { + case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() - case DATE => Random.nextInt() case LONG => Random.nextLong() - case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case _ => - // Using a random one-element map instead of an arbitrary object - Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) + case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) + case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) + case STRUCT(_) => + new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) + case ARRAY(_) => + new GenericArrayData(Array[Any](Random.nextInt(), 
Random.nextInt())) + case MAP(_) => + ArrayBasedMapData( + Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 952637c5f9cb8..6265e40a0a07b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -39,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + sqlContext.cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + sqlContext.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -69,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + sqlContext.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -81,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + sqlContext.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -96,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - ctx.cacheTable("timestamps") + sqlContext.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -108,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - 
ctx.cacheTable("withEmptyParts") + sqlContext.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 10000), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -172,23 +172,51 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { BigDecimal(Long.MaxValue.toString + ".12345"), new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), - new Timestamp(i), - (1 to i).toSeq, - (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + new Timestamp(i * 1000000L), + (i to i + 10).toSeq, + (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + sqlContext.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + sqlContext.table("InMemoryCache_different_data_types").collect()) + sqlContext.dropTempTable("InMemoryCache_different_data_types") + } + + test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { + val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val cached = df.cache() + // count triggers the caching action. It should not throw. + cached.count() + + // Make sure, the DataFrame is indeed cached. + assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + + // Check result. + checkAnswer( + cached, + sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") + ) + + // Drop the cache. 
+ cached.unpersist() + } + + test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") { + val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s") + data.cache() + assert(data.count() === 10) + assert(data.filter($"s" === "3").count() === 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index f4f6c7649bfa8..aa1605fee8c73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{StringType, ArrayType, DataType} +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, @@ -32,18 +33,18 @@ class TestNullableColumnAccessor[JvmType]( object TestNullableColumnAccessor { def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) : TestNullableColumnAccessor[JvmType] = { - // Skips the column type ID - buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) } } class NullableColumnAccessorSuite extends SparkFunSuite { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ Seq( - BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnAccessor(_) } @@ -63,19 +64,22 @@ class NullableColumnAccessorSuite extends SparkFunSuite { test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) + val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) (0 until 4).foreach { _ => - builder.appendFrom(randomRow, 0) - builder.appendFrom(nullRow, 0) + builder.appendFrom(proj(randomRow), 0) + builder.appendFrom(proj(nullRow), 0) } val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) + assert(converter(row.get(0, columnType.dataType)) + === converter(randomRow.get(0, columnType.dataType))) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 241d09ea205e9..91404577832a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -35,11 +36,13 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ Seq( - BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnBuilder(_) } @@ -48,12 +51,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val dataType = columnType.dataType + val proj = UnsafeProjection.create(Array[DataType](dataType)) + val converter = CatalystTypeConverters.createToScalaConverter(dataType) test(s"$typeName column builder: empty column") { val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } @@ -63,12 +68,11 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) } val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(0, "Wrong null count")(buffer.getInt()) } @@ -78,27 +82,22 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val nullRow = makeNullRow(1) (0 until 4).foreach { _ => - columnBuilder.appendFrom(randomRow, 0) - columnBuilder.appendFrom(nullRow, 0) + columnBuilder.appendFrom(proj(randomRow), 0) + columnBuilder.appendFrom(proj(nullRow), 0) } val buffer = columnBuilder.build() - assertResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) assertResult(4, "Wrong null count")(buffer.getInt()) // For null positions (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values + val actual = new GenericMutableRow(new Array[Any](1)) (0 until 4).foreach { _ => - val actual = if (columnType.isInstanceOf[GENERIC]) { - SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) - } else { - columnType.extract(buffer) - } - - assert(actual === randomRow.get(0, columnType.dataType), + columnType.extract(buffer, actual, 0) + assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), "Extracted value didn't equal to the original one") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index ab2644eb4581d..6b7401464f46f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -25,32 +25,32 @@ import org.apache.spark.sql.test.SQLTestData._ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - ctx.cacheTable("pruningData") + sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - ctx.uncacheTable("pruningData") + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 8998f5111124c..911d12e93e503 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fad93b014c237..ebdab1c26d7bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow @@ -31,14 +30,14 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class PlannerSuite extends 
SparkFunSuite with SharedSQLContext { +class PlannerSuite extends SharedSQLContext { import testImplicits._ setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -53,8 +52,8 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("unions are collapsed") { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,33 +80,30 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) - val fields = fieldTypes.zipWithIndex.map { - case (dataType, index) => StructField(s"c${index}", dataType, true) - } :+ StructField("key", IntegerType, true) - val schema = StructType(fields) - val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = ctx.sparkContext.parallelize(row :: Nil) - ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") - - val planned = sql( - """ - |SELECT l.a, l.b - |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan - - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - - ctx.dropTempTable("testLimit") + def checkPlan(fieldTypes: Seq[DataType]): Unit = { + withTempTable("testLimit") { + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + } } - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - val simpleTypes = NullType :: BooleanType :: @@ -124,7 +120,9 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { StringType :: BinaryType :: Nil - checkPlan(simpleTypes, newThreshold = 16434) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") { + checkPlan(simpleTypes) + } val complexTypes = ArrayType(DoubleType, true) :: @@ -136,36 +134,37 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", 
DoubleType, nullable = false))) :: Nil - checkPlan(complexTypes, newThreshold = 901617) - - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") { + checkPlan(complexTypes) + } } test("InMemoryRelation statistics propagation") { - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) - - testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { + withTempTable("tiny") { + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") - val a = testData.as("a") - val b = ctx.table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + val a = testData.as("a") + val b = sqlContext.table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + sqlContext.clearCache() + } + } } test("efficient limit -> project -> sort") { { val query = testData.select('key, 'value).sort('key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) } @@ -175,7 +174,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { // into it. 
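Both branches of this test expect the planner to pick TakeOrderedAndProject, which returns the same rows as a full sort followed by a limit while only ever keeping the top k rows (the query just below exercises the projected variant). A plain-Scala sketch of that bounded top-k idea, offered as an illustration rather than Spark's implementation:

import scala.collection.mutable

object TopKSketch {
  // Keep only the k smallest elements seen so far, so memory stays O(k)
  // instead of sorting the whole input before applying the limit.
  def topK[T](input: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
    val heap = mutable.PriorityQueue.empty[T](ord) // max-heap: head is the largest kept element
    input.foreach { x =>
      if (heap.size < k) {
        heap.enqueue(x)
      } else if (ord.lt(x, heap.head)) {
        heap.dequeue() // evict the current largest
        heap.enqueue(x)
      }
    }
    heap.toList.sorted(ord) // ascending: the same rows as sort-then-limit(k)
  }
}

For example, TopKSketch.topK(Iterator(5, 1, 4, 2, 3), 2) yields Seq(1, 2).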
val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) } @@ -355,6 +354,55 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } } + test("EnsureRequirements adds sort when there is no existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, + requiredChildOrdering = Seq(Seq(orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + val orderingA = SortOrder(Literal(1), Ascending) + val orderingB = SortOrder(Literal(2), Ascending) + assert(orderingA != orderingB) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA, orderingB)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } + // --------------------------------------------------------------------------------------------- } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index ef6ad59b71fb3..b3fceeab64cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -21,8 +21,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.test.SharedSQLContext -import 
org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType} +import org.apache.spark.sql.types.{ArrayType, StringType} import org.apache.spark.unsafe.types.UTF8String class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { @@ -32,27 +33,27 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +68,33 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - ctx.createDataFrame(input), + sqlContext.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(ctx) + SparkPlan.currentContext.set(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 0000000000000..63639681ef80a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. 
+ */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 8fa77b0fcb7b7..847c188a30333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext class SortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. @@ -35,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 5ab8f44faebf6..8549a6a0f6643 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -31,14 +31,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def _sqlContext: SQLContext - - /** - * Creates a DataFrame from a local Seq of Product. - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - _sqlContext.implicits.localSeqToDataFrameHolder(data) - } + protected def sqlContext: SQLContext /** * Runs the plan and makes sure the answer matches the expected result. 
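Editor's note on the `BadSparkContext` helper added above for SPARK-10548: its `childValue` override hands each child thread `new Properties(parent)`, which only installs the parent's property table as *defaults* rather than copying it, so anything the parent thread sets afterwards still leaks into the child. The minimal, self-contained sketch below is not part of the patch (the object name and printed key are illustrative only); it reproduces that leak outside Spark:

```scala
// Standalone sketch (not part of the patch): why BadSparkContext's childValue is unsafe.
// `new Properties(parent)` only installs the parent's table as *defaults*, so any key the
// parent sets later is still visible to the child through the defaults lookup.
import java.util.Properties

object InheritablePropsLeakDemo {
  def main(args: Array[String]): Unit = {
    val localProps = new InheritableThreadLocal[Properties] {
      override protected def childValue(parent: Properties): Properties = new Properties(parent)
      override protected def initialValue(): Properties = new Properties()
    }
    localProps.get() // initialize in the parent thread, as the suite does via sc.getLocalProperties

    val child = new Thread {
      override def run(): Unit = {
        // Prints "anything": the mutation below leaks into the child via the shared defaults table.
        println("child sees: " + localProps.get().getProperty("spark.sql.execution.id"))
      }
    }
    // childValue() already ran when the Thread was constructed, yet this still leaks:
    localProps.get().setProperty("spark.sql.execution.id", "anything")
    child.start()
    child.join()
  }
}
```

The non-buggy `SparkContext` exercised by the `goodSparkContext` branch of the test copies the parent's properties instead, so the same sequence would leave the child seeing `null`.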
@@ -98,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -122,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -149,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, _sqlContext) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -170,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -210,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -238,21 +231,21 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = _sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } ) - resolvedPlan.executeCollect().toSeq + resolvedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala index 48c3938ff87ba..c4358f409b6ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -17,12 +17,18 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + +import org.apache.spark.memory.MemoryManager import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.storage.{BlockId, BlockStatus} + /** * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. */ -class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) { +class TestShuffleMemoryManager + extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) { private var oom = false override def tryToAcquire(numBytes: Long): Long = { @@ -49,3 +55,21 @@ class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1 oom = true } } + +private class GrantEverythingMemoryManager extends MemoryManager { + override def acquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def releaseExecutionMemory(numBytes: Long): Unit = { } + override def releaseStorageMemory(numBytes: Long): Unit = { } + override def maxExecutionMemory: Long = Long.MaxValue + override def maxStorageMemory: Long = Long.MaxValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 3158458edb831..7a0f0dfd2b7f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -29,15 +29,16 @@ import org.apache.spark.sql.types._ * A test suite that generates randomized data to test the [[TungstenSort]] operator. 
*/ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder override def beforeAll(): Unit = { super.beforeAll() - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { try { - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) } finally { super.afterAll() } @@ -64,8 +65,7 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { } test("sorting updates peak execution memory") { - val sc = ctx.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), @@ -83,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = ctx.createDataFrame( - ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52f..1739798a24e0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -199,9 +200,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === - initialMemoryConsumption + 4096 * 16) + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -304,9 +303,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. 
- assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === - initialMemoryConsumption + 4096 * 16) + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -325,7 +322,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +330,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index bd02c73a26ace..1680d7e0a85ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.execution -import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} -import org.apache.spark.SparkFunSuite +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ +import org.apache.spark._ /** @@ -37,12 +42,18 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea } } -class UnsafeRowSerializerSuite extends SparkFunSuite { +class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { - val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] + val converter = unsafeRowConverter(schema) + converter(row) + } + + private def unsafeRowConverter(schema: Array[DataType]): Row 
=> UnsafeRow = { val converter = UnsafeProjection.create(schema) - converter.apply(internalRow) + (row: Row) => { + converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]) + } } test("toUnsafeRow() test helper method") { @@ -77,14 +88,67 @@ class UnsafeRowSerializerSuite extends SparkFunSuite { } test("close empty input stream") { - val baos = new ByteArrayOutputStream() - val dout = new DataOutputStream(baos) - dout.writeInt(-1) // EOF - dout.flush() - val input = new ClosableByteArrayInputStream(baos.toByteArray) + val input = new ClosableByteArrayInputStream(Array.empty) val serializer = new UnsafeRowSerializer(numFields = 2).newInstance() val deserializerIter = serializer.deserializeStream(input).asKeyValueIterator assert(!deserializerIter.hasNext) assert(input.closed) } + + test("SPARK-10466: external sorter spilling with unsafe row serializer") { + var sc: SparkContext = null + var outputFile: File = null + val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten + Utils.tryWithSafeFinally { + val conf = new SparkConf() + .set("spark.shuffle.spill.initialMemoryThreshold", "1") + .set("spark.shuffle.sort.bypassMergeThreshold", "0") + .set("spark.testing.memory", "80000") + + sc = new SparkContext("local", "test", conf) + outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "") + // prepare data + val converter = unsafeRowConverter(Array(IntegerType)) + val data = (1 to 10000).iterator.map { i => + (i, converter(Row(i))) + } + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + partitioner = Some(new HashPartitioner(10)), + serializer = Some(new UnsafeRowSerializer(numFields = 1))) + + // Ensure we spilled something and have to merge them later + assert(sorter.numSpills === 0) + sorter.insertAll(data) + assert(sorter.numSpills > 0) + + // Merging spilled files should not throw assertion error + val taskContext = + new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) + taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + } { + // Clean up + if (sc != null) { + sc.stop() + } + + // restore the spark env + SparkEnv.set(oldEnv) + + if (outputFile != null) { + outputFile.delete() + } + } + } + + test("SPARK-10403: unsafe row serializer with SortShuffleManager") { + val conf = new SparkConf().set("spark.shuffle.manager", "sort") + sc = new SparkContext("local", "test", conf) + val row = Row("Hello", 123) + val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) + val rowsRDD = sc.parallelize(Seq((0, unsafeRow), (1, unsafeRow), (0, unsafeRow))) + .asInstanceOf[RDD[Product2[Int, InternalRow]]] + val shuffled = new ShuffledRowRDD(rowsRDD, new UnsafeRowSerializer(2), 2) + shuffled.count() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 5fdb82b067728..cc0ac1b07c21a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -37,9 +37,10 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { () => new 
InterpretedMutableProjection(expr, schema) } - val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") - iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, - Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, + 0, Seq.empty, newMutableProjection, Seq.empty, None, + dummyAccum, dummyAccum, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1174b27732f22..7540223bf2771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -47,13 +47,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val factory = new JsonFactory() def enforceCorrectType(value: Any, dataType: DataType): Any = { val writer = new StringWriter() - val generator = factory.createGenerator(writer) - generator.writeObject(value) - generator.flush() + Utils.tryWithResource(factory.createGenerator(writer)) { generator => + generator.writeObject(value) + generator.flush() + } - val parser = factory.createParser(writer.toString) - parser.nextToken() - JacksonParser.convertField(factory, parser, dataType) + Utils.tryWithResource(factory.createParser(writer.toString)) { parser => + parser.nextToken() + JacksonParser.convertField(factory, parser, dataType) + } } val intNumber: Int = 2147483647 @@ -215,7 +217,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +236,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -262,7 +264,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -361,7 +363,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = 
sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -377,7 +379,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -449,7 +451,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -502,7 +504,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -526,7 +528,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -554,7 +556,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -573,9 +575,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +592,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path) + sqlContext.read.schema(schema).json(path) .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) @@ -603,7 +605,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = sqlContext.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -672,7 +674,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = sqlContext.read.schema(schema).json(path) assert(schema === 
jsonDF1.schema) @@ -689,7 +691,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -710,7 +712,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -738,7 +740,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -764,7 +766,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -782,7 +784,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -805,7 +807,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -823,64 +825,63 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = ctx.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. 
- checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val jsonDF = sqlContext.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + } + } } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -926,7 +927,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -949,7 +950,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -957,8 +958,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val primTable = sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * 
from primativeTable"), @@ -970,8 +971,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = sqlContext.read.json(complexFieldAndType1) + val compTable = sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1039,25 +1040,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1078,18 +1079,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1105,24 +1106,21 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val schemaWithSimpleMap = StructType( - StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try { - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } finally { - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempDir { dir => + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) + // order of MapType is not defined + assert(sqlContext.read.parquet(path).count() == 5) + + val df2 = sqlContext.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(path) + 
checkAnswer(sqlContext.read.parquet(path), df2.collect()) + } } } @@ -1142,19 +1140,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d1 = new File(root, "d1=1") // root/dt=1/col1=abc val p1_col1 = makePartition( - ctx.sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abc") // root/dt=1/col1=abd val p2 = makePartition( - ctx.sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abd") - ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( @@ -1163,4 +1161,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. 
+ // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. 
+ val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2864181cf91d5..713d1da1cb515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext private[json] trait TestJsonData { - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: Nil) def 
complexFieldAndType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,7 +198,7 @@ private[json] trait TestJsonData { """]""" :: Nil) - lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) - def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 91f3ce4d34c8b..0835bd123049b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -39,12 +39,13 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) - val fs = fsPath.getFileSystem(configuration) + val fs = fsPath.getFileSystem(hadoopConfiguration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { override def accept(path: Path): Boolean = pathFilter(path) }).toSeq.asJava - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + val footers = + ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true) footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index f067112cfca95..13fdd555a4c71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -55,7 +55,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex .where(Column(predicate)) val analyzedPredicate = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters }.flatten assert(analyzedPredicate.nonEmpty) @@ -219,7 +219,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("filter pushdown - string") { + // See https://issues.apache.org/jira/browse/SPARK-11153 + ignore("filter pushdown - string") { withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate( @@ -247,7 +248,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } - test("filter pushdown - binary") { + // See https://issues.apache.org/jira/browse/SPARK-11153 + ignore("filter pushdown - binary") { implicit class IntToBinary(int: Int) { def b: Array[Byte] = int.toString.getBytes("UTF-8") } @@ -297,4 +299,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-10829: Filter combine partition key and attribute doesn't work in DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + + // If the "part = 1" filter gets pushed down, this query will throw an exception since + // "part" is not a valid column in the actual Parquet file + checkAnswer( + sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"), + (2 to 3).map(i => Row(i, i.toString, 1))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 08d2b9dee99b0..72744799897be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -28,10 +28,10 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} -import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} +import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} import 
org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} @@ -99,16 +99,18 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } - test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .toDF() - // Parquet doesn't allow column names with spaces, have to add an alias here - .select($"_1" cast decimal as "dec") + testStandardAndLegacyModes("fixed-length decimals") { + def makeDecimalRDD(decimal: DecimalType): DataFrame = { + sqlContext + .range(1000) + // Parquet doesn't allow column names with spaces, have to add an alias here. + // Minus 500 here so that negative decimals are also tested. + .select((('id - 500) / 100.0) cast decimal as 'dec) + .coalesce(1) + } - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { + val combinations = Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37)) + for ((precision, scale) <- combinations) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) @@ -119,7 +121,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() @@ -132,22 +134,22 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("map") { + testStandardAndLegacyModes("map") { val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) checkParquetFile(data) } - test("array") { + testStandardAndLegacyModes("array") { val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) checkParquetFile(data) } - test("array and double") { + testStandardAndLegacyModes("array and double") { val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) checkParquetFile(data) } - test("struct") { + testStandardAndLegacyModes("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -157,7 +159,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("nested struct with array of array as field") { + testStandardAndLegacyModes("nested struct with array of array as field") { val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) withParquetDataFrame(data) { df => // Structs are converted to `Row`s @@ -167,7 +169,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - test("nested map with struct as value type") { + testStandardAndLegacyModes("nested map with struct as value type") { val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) withParquetDataFrame(data) { df => checkAnswer(df, data.map { case Tuple1(m) => @@ -205,14 +207,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("compression codec") { - def compressionCodecFor(path: String): String = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala - .flatMap(_.getColumns.asScala) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) + def compressionCodecFor(path: String, codecName: String): String = { + val codecs = for { + footer <- 
readAllFootersWithoutSummaryFiles(new Path(path), hadoopConfiguration) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + + assert(codecs.distinct === Seq(codecName)) codecs.head } @@ -222,7 +224,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) + compressionCodecFor(path, codec.name()) } } } @@ -277,16 +279,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("write metadata") { withTempPath { file => val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + val fs = FileSystem.getLocal(hadoopConfiguration) + val schema = StructType.fromAttributes(ScalaReflection.attributesFor[(Int, String)]) + writeMetadata(schema, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val expectedSchema = new CatalystSchemaConverter().convert(schema) + val actualSchema = readFooter(path, hadoopConfiguration).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) expectedSchema.checkContains(actualSchema) @@ -355,7 +356,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, + sparkContext.hadoopConfiguration, path, Collections.singletonList( new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) @@ -370,12 +371,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. 
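This and the following committer tests all wrap their bodies in the same clone-and-restore dance around `hadoopConfiguration`, because Hadoop 1 has no `Configuration.unset`. A minimal sketch of that pattern as a reusable helper; the name `withHadoopConf` is illustrative and not part of this patch:

    import scala.collection.JavaConverters._
    import org.apache.hadoop.conf.Configuration

    // Snapshot the configuration, run the test body, then restore every original entry.
    // Mirrors the try/finally blocks used in the committer tests in this suite.
    def withHadoopConf(conf: Configuration)(body: => Unit): Unit = {
      val clonedConf = new Configuration(conf)
      try {
        body
      } finally {
        conf.clear()
        clonedConf.asScala.foreach(entry => conf.set(entry.getKey, entry.getValue))
      }
    }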
try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -383,23 +384,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -407,25 +408,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) @@ -436,8 +437,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(message === "Intentional exception for testing purposes") } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } @@ -455,11 +456,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-7837 Do not close output writer twice when commitTask() fails") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Using a output committer that always fail when committing a task, so that both // `commitTask()` and `abortTask()` are invoked. 
- configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) @@ -483,10 +484,29 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } + + test("read dictionary encoded decimals written as INT32") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i32.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + } + + test("read dictionary encoded decimals written as INT64") { + checkAnswer( + // Decimal column in this file is encoded using plain dictionary + readResourceParquetFile("dec-in-i64.parquet"), + sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + } + + // TODO Adds test case for reading dictionary encoded decimals written as `FIXED_LEN_BYTE_ARRAY` + // The Parquet writer version Spark 1.6 and prior versions use is `PARQUET_1_0`, which doesn't + // provide dictionary encoding support for `FIXED_LEN_BYTE_ARRAY`. Should add a test here once + // we upgrade to `PARQUET_2_0`. } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index ed8bafb10c60b..3a23b8ed66808 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -465,7 +465,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation) => + case LogicalRelation(relation: ParquetRelation, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") @@ -517,7 +517,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index b290429c2a021..98333e58cada8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -17,23 +17,17 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.test.SharedSQLContext class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { - - private def readParquetProtobufFile(name: String): DataFrame = { - val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) - } - test("unannotated array of primitive type") { - checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) } test("unannotated array of struct") { checkAnswer( - readParquetProtobufFile("old-repeated-message.parquet"), + readResourceParquetFile("old-repeated-message.parquet"), Row( Seq( Row("First inner", null, null), @@ -41,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh Row(null, null, "Third inner")))) checkAnswer( - readParquetProtobufFile("proto-repeated-struct.parquet"), + readResourceParquetFile("proto-repeated-struct.parquet"), Row( Seq( Row("0 - 1", "0 - 2", "0 - 3"), Row("1 - 1", "1 - 2", "1 - 3")))) checkAnswer( - readParquetProtobufFile("proto-struct-with-array-many.parquet"), + readResourceParquetFile("proto-struct-with-array-many.parquet"), Seq( Row( Seq( @@ -66,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("struct with unannotated array") { checkAnswer( - readParquetProtobufFile("proto-struct-with-array.parquet"), + readResourceParquetFile("proto-struct-with-array.parquet"), Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) } test("unannotated array of struct with unannotated array") { checkAnswer( - readParquetProtobufFile("nested-array-struct.parquet"), + readResourceParquetFile("nested-array-struct.parquet"), Seq( Row(2, Seq(Row(1, Seq(Row(3))))), Row(5, Seq(Row(4, Seq(Row(6))))), @@ -81,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("unannotated array of string") { checkAnswer( - readParquetProtobufFile("proto-repeated-string.parquet"), + readResourceParquetFile("proto-repeated-string.parquet"), Seq( Row(Seq("hello", "world")), Row(Seq("good", "bye")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index a379523d67f80..baff7f5752a75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,6 +22,9 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, InternalRow} +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -30,6 +33,7 @@ import org.apache.spark.util.Utils * A test suite that tests various Parquet queries. 
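 *
 * Several of the SPARK-10301 cases below exercise requested-schema clipping: reading with a
 * user-supplied schema that only partially overlaps the schema stored in the Parquet files.
 * A rough sketch of the user-facing pattern (the path and field names are illustrative);
 * requested fields that are missing from the files come back as nulls:
 * {{{
 *   val userDefinedSchema = new StructType()
 *     .add("s",
 *       new StructType()
 *         .add("a", LongType, nullable = true)
 *         .add("d", LongType, nullable = true),
 *       nullable = true)
 *
 *   sqlContext.read.schema(userDefinedSchema).parquet(path)
 * }}}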
*/ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -40,22 +44,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(TableIdentifier("tmp")) } test("self-join") { @@ -118,9 +122,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = ctx.read.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -129,12 +133,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
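// (For reference, schema merging can also be toggled per read through the "mergeSchema" data
// source option, e.g. sqlContext.read.option("mergeSchema", "true").parquet(basePath); the
// SPARK-8990 test below checks that this option is respected even when the global
// PARQUET_SCHEMA_MERGING_ENABLED flag is disabled.)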
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -153,9 +157,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -171,19 +175,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -193,7 +197,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) - val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) val df = sqlContext.createDataFrame(rowRDD, schema) df.write.parquet(basePath) @@ -203,9 +207,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("SPARK-10005 Schema merging for nested struct") { - val sqlContext = _sqlContext - import sqlContext.implicits._ - withTempPath { dir => val path = dir.getCanonicalPath @@ -230,54 +231,168 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } - test("SPARK-10301 Clipping nested structs in requested schema") { + test("SPARK-10301 requested schema clipping - same schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) + + checkAnswer( + 
sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L))) + } + } + + // This test case is ignored because of parquet-mr bug PARQUET-370 + ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(null, null))) + } + } + + test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 1L, null, null))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, null, null, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - df.write.mode("append").parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", new StructType().add("a", LongType, nullable = true), nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Row(Row(0))) + Row(Row(0L, 1L))) } withTempPath { dir => val path = dir.getCanonicalPath - - val df1 = sqlContext + val df = sqlContext .range(1) - .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) - val df2 = sqlContext - .range(1, 2) - .selectExpr("NAMED_STRUCT('b', id, 'c', id) AS s") + df.write.parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("a", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0L, 3L))) + } + } + + test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = 
sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - df1.write.parquet(path) - df2.write.mode(SaveMode.Append).parquet(path) + df.write.parquet(path) - val userDefinedSchema = new StructType() - .add("s", - new StructType() - .add("a", LongType, nullable = true) - .add("c", LongType, nullable = true), - nullable = true) + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("b", LongType, nullable = true) + .add("c", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) checkAnswer( sqlContext.read.schema(userDefinedSchema).parquet(path), - Seq( - Row(Row(0, null)), - Row(Row(null, 1)))) + Row(Row(1L, 2L, null))) } + } + test("SPARK-10301 requested schema clipping - deeply nested struct") { withTempPath { dir => val path = dir.getCanonicalPath @@ -306,4 +421,176 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext Row(Row(Seq(Row(0, null))))) } } + + test("SPARK-10301 requested schema clipping - out of order") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("b", LongType, nullable = true) + .add("d", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, 1, null)), + Row(Row(null, 2, 4)))) + } + } + + test("SPARK-10301 requested schema clipping - schema merging") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") + .coalesce(1) + + df1.write.mode(SaveMode.Append).parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + checkAnswer( + sqlContext + .read + .option("mergeSchema", "true") + .parquet(path) + .selectExpr("s.a", "s.b", "s.c"), + Seq( + Row(0, null, 2), + Row(1, 2, 3))) + } + } + + testStandardAndLegacyModes("SPARK-10301 requested schema clipping - UDT") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr( + """NAMED_STRUCT( + | 'f0', CAST(id AS STRING), + | 'f1', NAMED_STRUCT( + | 'a', CAST(id + 1 AS INT), + | 'b', CAST(id + 2 AS LONG), + | 'c', CAST(id + 3.5 AS DOUBLE) + | ) + |) AS s + """.stripMargin) + .coalesce(1) + + df.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = + new StructType() + .add( + "s", + new StructType() + .add("f1", new NestedStructUDT, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(NestedStruct(1, 2L, 3.5D)))) + } + } + + test("expand UDT in StructType") { + val schema = new StructType().add("n", new NestedStructUDT, nullable = true) + val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true) + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in ArrayType") { + val schema = new StructType().add( + "n", + ArrayType( + 
elementType = new NestedStructUDT, + containsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + ArrayType( + elementType = new NestedStructUDT().sqlType, + containsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } + + test("expand UDT in MapType") { + val schema = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT, + valueContainsNull = false), + nullable = true) + + val expected = new StructType().add( + "n", + MapType( + keyType = IntegerType, + valueType = new NestedStructUDT().sqlType, + valueContainsNull = false), + nullable = true) + + assert(CatalystReadSupport.expandUDT(schema) === expected) + } +} + +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + case class NestedStruct(a: Integer, b: Long, c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = + new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(obj: Any): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + obj match { + case n: NestedStruct => + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = { + datum match { + case row: InternalRow => + NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 28c59a4abdd76..60fa81b1ab819 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser -import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -35,32 +34,29 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( testName: String, messageType: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testSchema( testName, StructType.fromAttributes(ScalaReflection.attributesFor[T]), messageType, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } protected def testParquetToCatalyst( testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - 
followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql <= parquet: $testName") { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) @@ -78,14 +74,13 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { val converter = new CatalystSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, - followParquetFormatSpec = followParquetFormatSpec) + writeLegacyParquetFormat = writeLegacyParquetFormat) test(s"sql => parquet: $testName") { val actual = converter.convert(sqlSchema) @@ -99,10 +94,9 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { testName: String, sqlSchema: StructType, parquetSchema: String, - binaryAsString: Boolean = true, - int96AsTimestamp: Boolean = true, - followParquetFormatSpec: Boolean = false, - isThriftDerived: Boolean = false): Unit = { + binaryAsString: Boolean, + int96AsTimestamp: Boolean, + writeLegacyParquetFormat: Boolean): Unit = { testCatalystToParquet( testName, @@ -110,8 +104,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) testParquetToCatalyst( testName, @@ -119,8 +112,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { parquetSchema, binaryAsString, int96AsTimestamp, - followParquetFormatSpec, - isThriftDerived) + writeLegacyParquetFormat) } } @@ -137,7 +129,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _6; |} """.stripMargin, - binaryAsString = false) + binaryAsString = false, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", @@ -149,7 +143,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | required int64 _4 (INT_64); | optional int32 _5 (DATE); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "string", @@ -158,7 +155,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); |} """.stripMargin, - binaryAsString = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[String]]( "binary enum as string", @@ -166,7 +165,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional binary _1 (ENUM); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", @@ -176,7 +178,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - standard", @@ -189,7 +194,9 @@ class ParquetSchemaInferenceSuite extends 
ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - non-standard", @@ -197,11 +204,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { |message root { | optional group _1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Seq[Integer]]]( "nullable array - standard", @@ -214,7 +224,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - standard", @@ -228,7 +240,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, String]]]( "map - non-standard", @@ -241,7 +255,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", @@ -253,7 +270,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - non-standard", @@ -266,7 +285,7 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array_element { + | optional group array { | required int32 _1; | required double _2; | } @@ -276,7 +295,10 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( "deeply nested type - standard", @@ -300,7 +322,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( "optional types", @@ -315,36 +339,9 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) - - // Parquet files generated by parquet-thrift are already handled by the schema converter, but - // let's leave this test here until both read path and write path are all updated. 
- ignore("thrift generated parquet schema") { - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchemaInference[( - Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", - """ - |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } - | } - | } - |} - """.stripMargin, - isThriftDerived = true) - } + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) } class ParquetSchemaSuite extends ParquetSchemaTest { @@ -360,8 +357,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on - val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) - val fromJson = ParquetTypesConverter.convertFromString(jsonString) + val fromCaseClassString = StructType.fromString(caseClassString) + val fromJson = StructType.fromString(jsonString) (fromCaseClassString, fromJson).zipped.foreach { (a, b) => assert(a.name == b.name) @@ -470,7 +467,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with nullable element type - 2", @@ -486,7 +486,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -499,7 +502,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 2", @@ -512,7 +518,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 3", @@ -523,7 +532,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 element; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 4", @@ -544,7 +556,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", @@ -563,7 +578,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = 
true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", @@ -582,7 +600,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 7 - " + @@ -592,7 +613,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | repeated int32 f1; |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: LIST with non-nullable element type 8 - " + @@ -612,7 +636,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | required int32 c2; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ======================================================= // Tests for converting Catalyst ArrayType to Parquet LIST @@ -633,7 +660,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", @@ -645,11 +674,14 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional group f1 (LIST) { | repeated group bag { - | optional int32 array_element; + | optional int32 array; | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", @@ -666,7 +698,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", @@ -680,7 +714,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | repeated int32 array; | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Parquet Map to Catalyst MapType @@ -701,7 +738,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 2", @@ -718,7 +758,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", @@ -735,7 +778,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -752,7 +798,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - 
""".stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 2", @@ -769,7 +818,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testParquetToCatalyst( "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", @@ -786,7 +838,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ==================================================== // Tests for converting Catalyst MapType to Parquet Map @@ -808,7 +863,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", @@ -825,7 +882,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 1 - standard", @@ -843,7 +903,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testCatalystToParquet( "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", @@ -860,7 +922,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | } | } |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) // ================================= // Tests for conversion for decimals @@ -873,7 +938,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(1, 0)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(8, 3) - standard", @@ -882,7 +949,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(8, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(9, 3) - standard", @@ -891,7 +960,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int32 f1 (DECIMAL(9, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(18, 3) - standard", @@ -900,7 +971,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional int64 f1 (DECIMAL(18, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(19, 3) - standard", @@ -909,7 +982,9 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); |} """.stripMargin, - followParquetFormatSpec = true) + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = false) testSchema( "DECIMAL(1, 0) - prior 
to 1.4.x", @@ -917,7 +992,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(8, 3) - prior to 1.4.x", @@ -925,7 +1003,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(9, 3) - prior to 1.4.x", @@ -933,7 +1014,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) testSchema( "DECIMAL(18, 3) - prior to 1.4.x", @@ -941,7 +1025,10 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """message root { | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); |} - """.stripMargin) + """.stripMargin, + binaryAsString = true, + int96AsTimestamp = true, + writeLegacyParquetFormat = true) private def testSchemaClipping( testName: String, @@ -1012,12 +1099,17 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11Type = new StructType().add("f011", DoubleType, nullable = true) - val f01Type = ArrayType(StringType, containsNull = false) + val f00Type = ArrayType(StringType, containsNull = false) + val f01Type = ArrayType( + new StructType() + .add("f011", DoubleType, nullable = true), + containsNull = false) + val f0Type = new StructType() - .add("f00", f01Type, nullable = false) - .add("f01", f11Type, nullable = false) + .add("f00", f00Type, nullable = false) + .add("f01", f01Type, nullable = false) val f1Type = ArrayType(IntegerType, containsNull = true) + new StructType() .add("f0", f0Type, nullable = false) .add("f1", f1Type, nullable = true) @@ -1046,7 +1138,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1061,13 +1153,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1075,7 +1167,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary f00_tuple (UTF8); | } | @@ -1095,7 +1187,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { parquetSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1110,13 +1202,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { """.stripMargin, catalystSchema = { - val 
f11ElementType = new StructType() + val f01ElementType = new StructType() .add("f011", DoubleType, nullable = true) .add("f012", LongType, nullable = true) val f0Type = new StructType() - .add("f00", ArrayType(StringType, containsNull = false), nullable = false) - .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) new StructType().add("f0", f0Type, nullable = false) }, @@ -1124,7 +1216,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema = """message root { | required group f0 { - | optional group f00 { + | optional group f00 (LIST) { | repeated binary array (UTF8); | } | @@ -1236,6 +1328,63 @@ class ParquetSchemaSuite extends ParquetSchemaTest { |} """.stripMargin) + testSchemaClipping( + "standard array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = false), nullable = true) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | + | optional group f01 (LIST) { + | repeated group list { + | required group element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + testSchemaClipping( "empty requested schema", @@ -1251,4 +1400,160 @@ class ParquetSchemaSuite extends ParquetSchemaTest { catalystSchema = new StructType(), expectedSchema = "message root {}") + + testSchemaClipping( + "disjoint field sets", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = + new StructType() + .add( + "f0", + new StructType() + .add("f02", FloatType, nullable = true) + .add("f03", DoubleType, nullable = true), + nullable = true), + + expectedSchema = + """message root { + | required group f0 { + | optional float f02; + | optional double f03; + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required group value { + | required int64 value_f1; 
+ | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int32 value_f0; + | required int64 value_f1; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val valueType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(IntegerType, valueType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required int32 key; + | required group value { + | required int64 value_f1; + | required double value_f2; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "standard map with complex key", + + parquetSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int32 value_f0; + | required int64 value_f1; + | } + | required int32 value; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val keyType = + new StructType() + .add("value_f1", LongType, nullable = false) + .add("value_f2", DoubleType, nullable = false) + + val f0Type = MapType(keyType, IntegerType, valueContainsNull = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 (MAP) { + | repeated group key_value { + | required group key { + | required int64 value_f1; + | required double value_f2; + | } + | required int32 value; + | } + | } + |} + """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 5dbc7d1630f27..8ffb01fc5b584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,11 +19,19 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} + import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,7 +41,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. 
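 *
 * The `testStandardAndLegacyModes` helper added below registers each test body twice, once with
 * `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` set to "false" and once with it set to "true", so a single
 * test covers both writer layouts. For example, `ParquetIOSuite` above uses it as:
 * {{{
 *   testStandardAndLegacyModes("map") {
 *     val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i")))
 *     checkParquetFile(data)
 *   }
 * }}}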
*/ private[sql] trait ParquetTest extends SQLTestUtils { - protected def _sqlContext: SQLContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -43,7 +50,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -55,7 +62,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -67,14 +74,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - _sqlContext.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( @@ -98,4 +105,43 @@ private[sql] trait ParquetTest extends SQLTestUtils { assert(partDir.mkdirs(), s"Couldn't create directory $partDir") partDir } + + protected def writeMetadata( + schema: StructType, path: Path, configuration: Configuration): Unit = { + val parquetSchema = new CatalystSchemaConverter().convert(schema) + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schema.json).asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + + protected def readAllFootersWithoutSummaryFiles( + path: Path, configuration: Configuration): Seq[Footer] = { + val fs = path.getFileSystem(configuration) + ParquetFileReader.readAllFootersInParallel(configuration, fs.getFileStatus(path)).asScala.toSeq + } + + protected def readFooter(path: Path, configuration: Configuration): ParquetMetadata = { + ParquetFileReader.readFooter( + configuration, + new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE), + ParquetMetadataConverter.NO_FILTER) + } + + protected def testStandardAndLegacyModes(testName: String)(f: => Unit): Unit = { + test(s"Standard mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { f } + } + + test(s"Legacy mode - $testName") { + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { f } + } + } + + protected def readResourceParquetFile(name: String): DataFrame = { + val url = Thread.currentThread().getContextClassLoader.getResource(name) + sqlContext.read.parquet(url.toString) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala new file mode 100644 index 0000000000000..0a2306c06646c --- 
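A minimal sketch of how the new `testStandardAndLegacyModes` and `withParquetDataFrame` helpers added above are meant to be combined; the suite name and data below are hypothetical, and each call registers a pair of tests that differ only in the `spark.sql.parquet.writeLegacyFormat` setting:

package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class ParquetModesExampleSuite extends QueryTest with ParquetTest with SharedSQLContext {
  // Registers "Standard mode - roundtrip" and "Legacy mode - roundtrip".
  testStandardAndLegacyModes("roundtrip") {
    withParquetDataFrame((1 to 10).map(i => (i, i.toString))) { df =>
      // Both writer modes must read back the same number of rows.
      assert(df.count() === 10L)
    }
  }
}
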
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.text + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.util.Utils + + +class TextSuite extends QueryTest with SharedSQLContext { + + test("reading text file") { + verifyFrame(sqlContext.read.format("text").load(testFile)) + } + + test("SQLContext.read.text() API") { + verifyFrame(sqlContext.read.text(testFile)) + } + + test("writing") { + val df = sqlContext.read.text(testFile) + + val tempFile = Utils.createTempDir() + tempFile.delete() + df.write.text(tempFile.getCanonicalPath) + verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath)) + + Utils.deleteRecursively(tempFile) + } + + test("error handling for invalid schema") { + val tempFile = Utils.createTempDir() + tempFile.delete() + + val df = sqlContext.range(2) + intercept[AnalysisException] { + df.write.text(tempFile.getCanonicalPath) + } + + intercept[AnalysisException] { + sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + } + } + + private def testFile: String = { + Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString + } + + /** Verifies data and schema. */ + private def verifyFrame(df: DataFrame): Unit = { + // schema + assert(df.schema == new StructType().add("text", StringType)) + + // verify content + val data = df.collect() + assert(data(0) == Row("This is a test file for the text data source")) + assert(data(1) == Row("1+1")) + // non ascii characters are not allowed in the code, so we disable the scalastyle here. + // scalastyle:off + assert(data(2) == Row("数据砖头")) + // scalastyle:on + assert(data(3) == Row("\"doh\"")) + assert(data.length == 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 53a0e53fd7719..dcbfdca71acb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} * without serializing the hashed relation, which does not happen in local mode. 
*/ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - private var sc: SparkContext = null - private var sqlContext: SQLContext = null + protected var sqlContext: SQLContext = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -44,15 +43,14 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { val conf = new SparkConf() .setMaster("local-cluster[2,1,1024]") .setAppName("testing") - sc = new SparkContext(conf) + val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - sc.stop() - sc = null + sqlContext.sparkContext.stop() sqlContext = null } @@ -60,7 +58,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { * Test whether the specified broadcast join updates the peak execution memory accumulator. */ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 4c9187a9a7106..e5fd9e277fc61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index cc649b9bd4c45..da58e96f3e6f7 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder - private lazy val myUpperCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myUpperCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), Row(3, "C"), @@ -39,8 +40,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myLowerCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), Row(3, "c"), @@ -211,4 +212,18 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) } + { + lazy val left = Seq((1, Some(0)), (2, None)).toDF("a", "b") + lazy val right = Seq((1, Some(0)), (2, None)).toDF("a", "b") + testInnerJoin( + "inner join, null safe", + left, + right, + () => (left.col("b") <=> right.col("b")).expr, + Seq( + (1, 0, 1, 0), + (2, null, 2, null) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a1a617d7b7398..09e0237a7cc50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -40,8 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, -1.0), @@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + 
expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } if (joinType != FullOuter) { test(s"$testName using BroadcastHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } + } - test(s"$testName using SortMergeOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = false) - } + test(s"$testName using SortMergeOuterJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(sqlContext).apply( + SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index baa86e320d986..3afd762942bcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), @@ -40,8 +40,8 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(2, 3.0), Row(2, 3.0), Row(3, 2.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 0000000000000..efc3227dd60d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 0000000000000..bbd94d8da2d11 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.dsl.expressions._ + + +class ExpandNodeSuite extends LocalNodeTest { + + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() + } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index 07209f3779248..4eadce646d379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 +17,29 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(condition.expr, node), - testData.filter(condition).collect() - ) +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(condition.expr, node), - emptyTestData.filter(condition).collect() - ) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala new file mode 100644 index 0000000000000..8c2e78b2a9db7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -0,0 +1,160 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.mockito.Mockito.{mock, when} + +import org.apache.spark.broadcast.TorrentBroadcast +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression} +import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} + +class HashJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } + + /** + * Builds a [[HashedRelation]] based on a resolved `buildKeys` + * and a resolved `buildNode`. + */ + private def buildHashedRelation( + conf: SQLConf, + buildKeys: Seq[Expression], + buildNode: LocalNode): HashedRelation = { + + val isUnsafeMode = + conf.codegenEnabled && + conf.unsafeEnabled && + UnsafeProjection.canSupport(buildKeys) + + val buildSideKeyGenerator = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildNode.output) + } else { + new InterpretedMutableProjection(buildKeys, buildNode.output) + } + + buildNode.prepare() + buildNode.open() + val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) + buildNode.close() + + hashedRelation + } + + /** + * Test inner hash join with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => { + val binaryHashJoinNode = + BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2) + resolveExpressions(binaryHashJoinNode) + } + val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => { + val leftKeys = Seq('id1.attr) + val rightKeys = Seq('id2.attr) + // Figure out the build side and stream side. + val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match { + case BuildLeft => (node1, leftKeys, node2, rightKeys) + case BuildRight => (node2, rightKeys, node1, leftKeys) + } + // Resolve the expressions of the build side and then create a HashedRelation. 
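+        // The TorrentBroadcast is mocked (via Mockito) so that BroadcastHashJoinNode can read the
+        // prebuilt HashedRelation through `.value` without creating a real broadcast variable.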
+ val resolvedBuildNode = resolveExpressions(buildNode) + val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode) + val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode) + val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]]) + when(broadcastHashedRelation.value).thenReturn(hashedRelation) + + val hashJoinNode = + BroadcastHashJoinNode( + conf, + streamedKeys, + streamedNode, + buildSide, + resolvedBuildNode.output, + broadcastHashedRelation) + resolveExpressions(hashJoinNode) + } + + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + + Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput === expectedOutput) + } + } + + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 0000000000000..c0ad2021b204a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,37 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class IntersectNodeSuite extends LocalNodeTest { + + test("basic") { + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 523c02f4a6014..fb790636a3689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 0000000000000..0d1ed99eec6cd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,73 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + + +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray + + test("basic open, next, fetch, close") { + val node = new DummyNode(kvIntAttributes, data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { case (k, v) => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyNode(kvIntAttributes, data) + val iter = node.asIterator + node.open() + data.foreach { case (k, v) => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyNode(kvIntAttributes, data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) + node.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 95f06081bd0a8..615c417093612 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -17,130 +17,72 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} +import org.apache.spark.sql.types.{IntegerType, StringType} + class LocalNodeTest extends SparkFunSuite { - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. 
- */ - protected def checkAnswer( - input: DataFrame, - nodeFunction: LocalNode => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - input :: Nil, - nodes => nodeFunction(nodes.head), - expectedAnswer, - sortAnswers) - } + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows */ - protected def checkAnswer2( - left: DataFrame, - right: DataFrame, - nodeFunction: (LocalNode, LocalNode) => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - left :: right :: Nil, - nodes => nodeFunction(nodes(0), nodes(1)), - expectedAnswer, - sortAnswers) + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } } /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. 
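+   * Unlike the analyzer, this simply matches each [[UnresolvedAttribute]] by name against the
+   * node's own output attributes and fails the test if a name cannot be found.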
*/ - protected def doCheckAnswer( - input: Seq[DataFrame], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - LocalNodeTest.checkAnswer( - input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { - case Some(errorMessage) => fail(errorMessage) - case None => + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } - protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { - new SeqScanNode( - df.queryExecution.sparkPlan.output, - df.queryExecution.toRdd.map(_.copy()).collect()) - } - -} - -/** - * Helper methods for writing tests of individual local physical operators. - */ -object LocalNodeTest { - /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Resolve all expressions in `expressions` based on the `output` of `localNode`. + * It assumes that all expressions in the `localNode` are resolved. */ - def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { - - val outputNode = nodeFunction(input) - - val outputResult: Seq[Row] = try { - outputNode.collect() - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing local plan: - | $outputNode - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => - s""" - | Results do not match for local plan: - | $outputNode - | $errorMessage - """.stripMargin + protected def resolveExpressions( + expressions: Seq[Expression], + localNode: LocalNode): Seq[Expression] = { + require(localNode.expressions.forall(_.resolved)) + val inputMap = localNode.output.map { a => (a.name, a) }.toMap + expressions.map { expression => + expression.transformUp { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 0000000000000..40299d9d5ee37 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,145 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) + } + } + } + + /** + * Test outer nested loop joins with varying degrees of matches. + */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case 
(k, v) => (k, "cooper" + v) } + runTest(joinType, someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index ffcf092e2c66a..02ecb23d34b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case (id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 
to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 0000000000000..a3e83bbd51457 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + + +class SampleNodeSuite extends LocalNodeTest { + + private def testSample(withReplacement: Boolean): Unit = { + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 0000000000000..42ebc7bfcaadc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import scala.util.Random + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder + + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) + } + } + + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index 34670287c3e1d..666b0235c061d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i => (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 80006bf077fe8..cdd885ba14203 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -36,7 +36,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") + val l = SQLMetrics.createLongMetric(sparkContext, "long") val f = () => { l += 1L l.add(1L) @@ -50,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. - val l = ctx.sparkContext.accumulator(0L) + val l = sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -71,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet df.collect() - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => @@ -93,7 +93,16 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { }.toMap (node.id, node.name -> nodeMetrics) }.toMap - assert(expectedMetrics === actualMetrics) + + assert(expectedMetrics.keySet === actualMetrics.keySet) + for (nodeId <- expectedMetrics.keySet) { + val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId) + val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId) + assert(expectedNodeName === actualNodeName) + for (metricName <- expectedMetricsMap.keySet) { + assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName)) + } + } } else { // TODO Remove this "else" once we fix the race condition that missing the JobStarted event. 
// Since we cannot track all jobs, the metric values could be wrong and we should not check @@ -474,22 +483,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. - assert(metricValues.values.toSeq === Seq(2L)) + assert(metricValues.values.toSeq === Seq("2")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 80d1e88956949..c15aac775096c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -54,9 +54,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { details = "" ) - private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( taskId = taskId, - attempt = attempt, + attemptNumber = attemptNumber, // The following fields are not used in tests index = 0, launchTime = 0, @@ -74,7 +74,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("basic") { - val listener = new SQLListener(ctx) + def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = { + assert(actual === expected.mapValues(_.toString)) + } + + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -114,7 +118,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( // (task id, stage id, stage attempt, metrics) @@ -122,7 +126,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3)) // Retrying a stage should 
reset the metrics listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) @@ -133,7 +137,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Ignore the task end for the first attempt listener.onTaskEnd(SparkListenerTaskEnd( @@ -144,7 +148,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(0, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2)) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -162,7 +166,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(1, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5)) // Summit a new stage listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) @@ -173,7 +177,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) ))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7)) // Finish two tasks listener.onTaskEnd(SparkListenerTaskEnd( @@ -191,7 +195,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { createTaskInfo(1, 0), createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) assert(executionUIData.runningJobs === Seq(0)) assert(executionUIData.succeededJobs.isEmpty) @@ -208,11 +212,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(executionUIData.succeededJobs === Seq(0)) assert(executionUIData.failedJobs.isEmpty) - assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11)) } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +245,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +285,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -309,7 +313,24 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { assert(executionUIData.failedJobs === Seq(0)) } - ignore("no memory leak") { + 
test("SPARK-11126: no memory leak when running non SQL jobs") { + val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size + sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + // listener should ignore the non SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + + sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + // listener should save the SQL stage + assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + } + +} + +class SQLListenerMemoryLeakSuite extends SparkFunSuite { + + test("no memory leak") { val conf = new SparkConf() .setMaster("local") .setAppName("test") @@ -344,5 +365,4 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { sc.stop() } } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d8c9a08d84c61..d530b1a469ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -255,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("Basic API") { - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -330,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -340,8 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ 
-349,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -396,7 +396,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -408,18 +408,23 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) + assert(JdbcDialects.get("jdbc:derby:db") == DerbyDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } test("quote column names by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val Derby = JdbcDialects.get("jdbc:derby:db") val columns = Seq("abc", "key") val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) + val DerbyColumns = columns.map(Derby.quoteIdentifier(_)) assert(MySQLColumns === Seq("`abc`", "`key`")) assert(PostgresColumns === Seq(""""abc"""", """"key"""")) + assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } test("Dialect unregister") { @@ -450,4 +455,33 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("PostgresDialect type mapping") { + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + } + + test("DerbyDialect jdbc type mapping") { + val derbyDialect = JdbcDialects.get("jdbc:derby:db") + assert(derbyDialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(derbyDialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN") + } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + 
assert(h2.getTableExistsQuery(table) == defaultQuery) + assert(derby.getTableExistsQuery(table) == defaultQuery) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 5dc3a2c07b8c7..e23ee6693133b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,13 +22,12 @@ import java.util.Properties import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null @@ -76,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon conn1.close() } - private lazy val sc = ctx.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) private lazy val schema2 = StructType( @@ -91,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) 
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -143,14 +141,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 9bc3f6bcf6fce..6fc9febe49707 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,10 +26,8 @@ import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils - class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index d74d29fb0beb0..af04079ec895a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ - private[sql] abstract class DataSourceTest extends QueryTest { - protected def _sqlContext: SQLContext // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(_sqlContext.sparkContext) + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 084d83f6e9bff..5b70d258d6ce3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 79b6e9b45c009..c9791879ec74c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = ctx.range(100).select($"id", lit(1).as("data")) + val df = sqlContext.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = ctx.range(100) + val base = sqlContext.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index f18546b4c2d9b..10d261368993d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with 
SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null private var df: DataFrame = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 3fc02df954e23..520dea7f7dd92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} * A collection of sample data used in SQL tests. */ private[sql] trait SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext } import internalImplicits._ @@ -37,21 +37,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -63,7 +63,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -71,14 +71,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -90,7 +90,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -102,7 +102,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( BinaryData("12".getBytes, 1) :: BinaryData("22".getBytes, 5) :: BinaryData("122".getBytes, 3) :: @@ -113,7 +113,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = 
_sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -125,7 +125,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -135,7 +135,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) rdd.toDF().registerTempTable("arrayData") @@ -143,7 +143,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -154,13 +154,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -168,7 +168,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -178,7 +178,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -188,7 +188,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -197,13 +197,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -212,13 +212,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = 
_sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -226,7 +226,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -234,7 +234,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -246,7 +246,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(sqlContext != null, "attempted to initialize test data before SQLContext.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index dc08306ad9cb4..9214569f18e93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils @@ -47,13 +47,13 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sparkContext = sqlContext.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = _sqlContext.sql _ + protected lazy val sql = sqlContext.sql _ /** * A helper object for importing SQL implicits. @@ -63,7 +63,14 @@ private[sql] trait SQLTestUtils * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext + + // This must live here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } } /** @@ -84,8 +91,8 @@ private[sql] trait SQLTestUtils /** * The Hadoop configuration used by the active [[SQLContext]]. 
*/ - protected def configuration: Configuration = { - _sqlContext.sparkContext.hadoopConfiguration + protected def hadoopConfiguration: Configuration = { + sparkContext.hadoopConfiguration } /** @@ -96,12 +103,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(_sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) - case (key, None) => _sqlContext.conf.unsetConf(key) + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) } } } @@ -133,7 +140,7 @@ private[sql] trait SQLTestUtils * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(_sqlContext.dropTempTable) + try f finally tableNames.foreach(sqlContext.dropTempTable) } /** @@ -142,7 +149,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - _sqlContext.sql(s"DROP TABLE IF EXISTS $name") + sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -155,12 +162,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - _sqlContext.sql(s"CREATE DATABASE $dbName") + sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -168,8 +175,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - _sqlContext.sql(s"USE $db") - try f finally _sqlContext.sql(s"USE default") + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") } /** @@ -177,7 +184,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(_sqlContext, plan) + DataFrame(sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index d23c6a0732669..963d10eed62ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.{ColumnName, SQLContext} +import org.apache.spark.sql.SQLContext /** @@ -36,9 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def ctx: TestSQLContext = _ctx - protected def sqlContext: TestSQLContext = _ctx - protected override def _sqlContext: SQLContext = _ctx + protected def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. 
@@ -64,15 +62,4 @@ trait SharedSQLContext extends SQLTestUtils { super.afterAll() } } - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - // This must be duplicated here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 92ef2f7d74ba1..c89a1516503e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -31,13 +31,17 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel new SparkConf().set("spark.sql.testkey", "true"))) } - // Use fewer partitions to speed up testing - protected[sql] override def createSession(): SQLSession = new this.SQLSession() + protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */ - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) + clear() + + override def clear(): Unit = { + super.clear() + + // Make sure we start with the default test configs even after clear + TestSQLContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } } } @@ -47,6 +51,17 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } private object testData extends SQLTestData { - protected override def _sqlContext: SQLContext = self + protected override def sqlContext: SQLContext = self } } + +private[sql] object TestSQLContext { + + /** + * A map used to store all confs that need to be overridden in sql/core unit tests. + */ + val overrideConfs: Map[String, String] = + Map( + // Fewer shuffle partitions to speed up testing. + SQLConf.SHUFFLE_PARTITIONS.key -> "5") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala new file mode 100644 index 0000000000000..eb056cd519717 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkException +import org.apache.spark.sql.{functions, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +import scala.collection.mutable.ArrayBuffer + +class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import functions._ + + test("execute callback functions when a DataFrame action finished successfully") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)] + val listener = new QueryExecutionListener { + // Only test successful case here, so no need to implement `onFailure` + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + metrics += ((funcName, qe, duration)) + } + } + sqlContext.listenerManager.register(listener) + + val df = Seq(1 -> "a").toDF("i", "j") + df.select("i").collect() + df.filter($"i" > 0).count() + + assert(metrics.length == 2) + + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3 > 0) + + assert(metrics(1)._1 == "count") + assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) + assert(metrics(1)._3 > 0) + } + + test("execute callback functions when a DataFrame action failed") { + val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] + val listener = new QueryExecutionListener { + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + metrics += ((funcName, qe, exception)) + } + + // Only test failed case here, so no need to implement `onSuccess` + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} + } + sqlContext.listenerManager.register(listener) + + val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } + val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") + + // Ignore the log when we are expecting an exception. 
+ sparkContext.setLogLevel("FATAL") + val e = intercept[SparkException](df.select(errorUdf($"i")).collect()) + + assert(metrics.length == 1) + assert(metrics(0)._1 == "collect") + assert(metrics(0)._2.analyzed.isInstanceOf[Project]) + assert(metrics(0)._3.getMessage == e.getMessage) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3566c87dd248c..b5b2143292a69 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -93,6 +93,10 @@ ${project.version} test + + org.apache.spark + spark-test-tags_${scala.binary.version} + target/scala-${scala.binary.version}/classes diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0b..a4fd0c3ce9702 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -55,7 +55,6 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - sqlContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) @@ -93,6 +92,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. 
+ if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 306f98bcb5344..719b03e1c7c71 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,19 +20,15 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Arrays, Map => JMap, UUID} +import java.util.{Arrays, UUID, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hive.service.cli._ -import org.apache.hadoop.hive.ql.metadata.Hive -import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.Utils +import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession @@ -40,7 +36,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} private[hive] class SparkExecuteStatementOperation( @@ -143,30 +139,15 @@ private[hive] class SparkExecuteStatementOperation( if (!runInBackground) { runInternal() } else { - val parentSessionState = SessionState.get() - val hiveConf = getConfigForOperation() val sparkServiceUGI = Utils.getUGI() - val sessionHive = getCurrentHive() - val currentSqlSession = hiveContext.currentSession // Runnable impl to call runInternal asynchronously, // from a different thread val backgroundOperation = new Runnable() { override def run(): Unit = { - val doAsAction = new PrivilegedExceptionAction[Object]() { - override def run(): Object = { - - // User information is part of the metastore client member in Hive - hiveContext.setSession(currentSqlSession) - // Always use the latest class loader provided by executionHive's state. 
- val executionHiveClassLoader = - hiveContext.executionHive.state.getConf.getClassLoader - sessionHive.getConf.setClassLoader(executionHiveClassLoader) - parentSessionState.getConf.setClassLoader(executionHiveClassLoader) - - Hive.set(sessionHive) - SessionState.setCurrentSessionState(parentSessionState) + val doAsAction = new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { try { runInternal() } catch { @@ -174,7 +155,6 @@ private[hive] class SparkExecuteStatementOperation( setOperationException(e) log.error("Error running hive query: ", e) } - return null } } @@ -191,7 +171,7 @@ private[hive] class SparkExecuteStatementOperation( try { // This submit blocks if no background threads are available to run this operation val backgroundHandle = - getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation) + parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation) setBackgroundHandle(backgroundHandle) } catch { case rejected: RejectedExecutionException => @@ -210,6 +190,11 @@ private[hive] class SparkExecuteStatementOperation( statementId = UUID.randomUUID().toString logInfo(s"Running query '$statement' with $statementId") setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. + val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + HiveThriftServer2.listener.onStatementStart( statementId, parentSession.getSessionHandle.getSessionId.toString, @@ -279,43 +264,4 @@ private[hive] class SparkExecuteStatementOperation( } } } - - /** - * If there are query specific settings to overlay, then create a copy of config - * There are two cases we need to clone the session config that's being passed to hive driver - * 1. Async query - - * If the client changes a config setting, that shouldn't reflect in the execution - * already underway - * 2. 
confOverlay - - * The query specific settings should only be applied to the query config and not session - * @return new configuration - * @throws HiveSQLException - */ - private def getConfigForOperation(): HiveConf = { - var sqlOperationConf = getParentSession().getHiveConf() - if (!getConfOverlay().isEmpty() || runInBackground) { - // clone the partent session config for this query - sqlOperationConf = new HiveConf(sqlOperationConf) - - // apply overlay query specific settings, if any - getConfOverlay().asScala.foreach { case (k, v) => - try { - sqlOperationConf.verifyAndSet(k, v) - } catch { - case e: IllegalArgumentException => - throw new HiveSQLException("Error applying statement specific settings", e) - } - } - } - return sqlOperationConf - } - - private def getCurrentHive(): Hive = { - try { - return Hive.get() - } catch { - case e: HiveException => - throw new HiveSQLException("Failed to get current Hive object", e); - } - } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 92ac0ec3fca29..33aaead3fbf96 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -36,7 +36,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: extends SessionManager(hiveServer) with ReflectedCompositeService { - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) + private lazy val sparkSqlOperationManager = new SparkSQLOperationManager() override def init(hiveConf: HiveConf) { setSuperField(this, "hiveConf", hiveConf) @@ -60,13 +60,15 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: sessionConf: java.util.Map[String, String], withImpersonation: Boolean, delegationToken: String): SessionHandle = { - hiveContext.openSession() val sessionHandle = super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation, delegationToken) val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) + val ctx = hiveContext.newSession() + ctx.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) + sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle } @@ -74,7 +76,6 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) super.closeSession(sessionHandle) sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() + sparkSqlOperationManager.sessionToContexts.remove(sessionHandle) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index c8031ed0f3437..476651a559d2c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,20 +30,21 @@ import 
org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, R /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. */ -private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) +private[thriftserver] class SparkSQLOperationManager() extends OperationManager with Logging { val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") val sessionToActivePool = Map[SessionHandle, String]() + val sessionToContexts = Map[SessionHandle, HiveContext]() override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { - + val hiveContext = sessionToContexts(parentSession.getSessionHandle) val runInBackground = async && hiveContext.hiveThriftServerAsync val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(hiveContext, sessionToActivePool) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index e59a14ec00d5c..76d1591a235c2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -96,7 +96,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { buffer += s"${new Timestamp(new Date().getTime)} - $source> $line" // If we haven't found all expected answers and another expected answer comes up... - if (next < expectedAnswers.size && line.startsWith(expectedAnswers(next))) { + if (next < expectedAnswers.size && line.contains(expectedAnswers(next))) { next += 1 // If all expected answers have been found... 
if (next == expectedAnswers.size) { @@ -159,7 +159,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" -> "OK", "CACHE TABLE hive_test;" - -> "Time taken: ", + -> "", "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" @@ -180,7 +180,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { "CREATE TABLE hive_test(key INT, val STRING);" -> "OK", "SHOW TABLES;" - -> "Time taken: " + -> "hive_test" ) runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( @@ -210,7 +210,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" - -> "Time taken:", + -> "", "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index b72249b3bf8c0..ff8ca0150649d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -21,6 +21,7 @@ import java.io.File import java.net.URL import java.sql.{Date, DriverManager, SQLException, Statement} +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ @@ -204,6 +205,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { import org.apache.spark.sql.SQLConf var defaultV1: String = null var defaultV2: String = null + var data: ArrayBuffer[Int] = null withMultipleConnectionJdbcStatement( // create table @@ -213,10 +215,16 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { "DROP TABLE IF EXISTS test_map", "CREATE TABLE test_map(key INT, value STRING)", s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map", - "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC") + "CACHE TABLE test_table AS SELECT key FROM test_map ORDER BY key DESC", + "CREATE DATABASE db1") queries.foreach(statement.execute) + val plan = statement.executeQuery("explain select * from test_table") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") val buf1 = new collection.mutable.ArrayBuffer[Int]() while (rs1.next()) { @@ -232,6 +240,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() assert(buf1 === buf2) + + data = buf1 }, // first session, we get the default value of the session status @@ -288,56 +298,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() }, - // accessing the cached data in another session + // try to access the cached data in another session { statement => - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) + // Cached temporary table can't be accessed by other sessions + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") } - rs1.close() - val rs2 
= statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) + val plan = statement.executeQuery("explain select key from test_map ORDER BY key DESC") + plan.next() + plan.next() + assert(plan.getString(1).contains("InMemoryColumnarTableScan")) + + val rs = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") + val buf = new collection.mutable.ArrayBuffer[Int]() + while (rs.next()) { + buf += rs.getInt(1) } - rs2.close() + rs.close() + assert(buf === data) + }, - assert(buf1 === buf2) - statement.executeQuery("UNCACHE TABLE test_table") + // switch another database + { statement => + statement.execute("USE db1") - // TODO need to figure out how to determine if the data loaded from cache - val rs3 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf3 = new collection.mutable.ArrayBuffer[Int]() - while (rs3.next()) { - buf3 += rs3.getInt(1) + // there is no test_map table in db1 + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") } - rs3.close() - assert(buf1 === buf3) + statement.execute("CREATE TABLE test_map2(key INT, value STRING)") }, - // accessing the uncached table + // access default database { statement => - // TODO need to figure out how to determine if the data loaded from cache - val rs1 = statement.executeQuery("SELECT key FROM test_table ORDER BY KEY DESC") - val buf1 = new collection.mutable.ArrayBuffer[Int]() - while (rs1.next()) { - buf1 += rs1.getInt(1) + // current database should still be `default` + intercept[SQLException] { + statement.executeQuery("SELECT key FROM test_map2") } - rs1.close() - val rs2 = statement.executeQuery("SELECT key FROM test_map ORDER BY KEY DESC") - val buf2 = new collection.mutable.ArrayBuffer[Int]() - while (rs2.next()) { - buf2 += rs2.getInt(1) - } - rs2.close() - - assert(buf1 === buf2) + statement.execute("USE db1") + // access test_map2 + statement.executeQuery("SELECT key from test_map2") } ) } @@ -431,6 +436,32 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } ) } + + test("Checks Hive version via SET -v") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET -v") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } + + test("Checks Hive version via SET") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET") + + val conf = mutable.Map.empty[String, String] + while (resultSet.next()) { + conf += resultSet.getString(1) -> resultSet.getString(2) + } + + assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) + } + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab309e0a1d36b..eed9e436f9af7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -25,10 +25,12 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive +import 
org.apache.spark.tags.ExtendedHiveTest /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( @@ -682,6 +684,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_file_with_space_in_the_name", "loadpart1", "louter_join_ppr", + "macro", "mapjoin_distinct", "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index be1607476e254..d96f3e2b9f62b 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml @@ -58,6 +58,10 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-test-tags_${scala.binary.version} + @@ -84,21 +88,11 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 4886b68eeaf76..f82323a1cdd94 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -154,34 +154,40 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { var lastClickedBatch = null; var lastTimeout = null; + function isFailedBatch(batchTime) { + return $("#batch-" + batchTime).attr("isFailed") == "true"; + } + // Add points to the line. However, we make it invisible at first. But when the user moves mouse // over a point, it will be displayed with its detail. svg.selectAll(".point") .data(data) .enter().append("circle") - .attr("stroke", "white") // white and opacity = 0 make it invisible - .attr("fill", "white") - .attr("opacity", "0") + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) // white and opacity = 0 make it invisible + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) .style("cursor", "pointer") .attr("cx", function(d) { return x(d.x); }) .attr("cy", function(d) { return y(d.y); }) - .attr("r", function(d) { return 3; }) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}) .on('mouseover', function(d) { var tip = formatYValue(d.y) + " " + unitY + " at " + timeFormat[d.x]; showBootstrapTooltip(d3.select(this).node(), tip); // show the point d3.select(this) - .attr("stroke", "steelblue") - .attr("fill", "steelblue") - .attr("opacity", "1"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "steelblue";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? 
"red" : "steelblue";}) + .attr("opacity", "1") + .attr("r", "3"); }) .on('mouseout', function() { hideBootstrapTooltip(d3.select(this).node()); // hide the point d3.select(this) - .attr("stroke", "white") - .attr("fill", "white") - .attr("opacity", "0"); + .attr("stroke", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("fill", function(d) { return isFailedBatch(d.x) ? "red" : "white";}) + .attr("opacity", function(d) { return isFailedBatch(d.x) ? "1" : "0";}) + .attr("r", function(d) { return isFailedBatch(d.x) ? "2" : "0";}); }) .on("click", function(d) { if (lastTimeout != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index cd5d960369c05..8a6050f5227bf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) +class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.appName @@ -49,6 +49,8 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) // Reload properties for the checkpoint application since user wants to set a reload property // or spark had changed its value and user wants to set it back. val propertiesToReload = List( + "spark.yarn.app.id", + "spark.yarn.app.attemptId", "spark.driver.host", "spark.driver.port", "spark.master", @@ -319,7 +321,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) - val compressionCodec = CompressionCodec.createCodec(conf) + var readError: Exception = null checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { @@ -330,13 +332,15 @@ object CheckpointReader extends Logging { return Some(cp) } catch { case e: Exception => + readError = e logWarning("Error reading checkpoint from file " + file, e) } }) // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + throw new SparkException( + s"Failed to read checkpoint from directory $checkpointPath", readError) } None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 40789c66f3991..1b0b7890b3b00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -38,9 +38,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def start(time: Time) { this.synchronized { - if (zeroTime != null) { - throw new Exception("DStream graph computation already started") - } + require(zeroTime == null, "DStream graph computation already started") zeroTime = time startTime = time outputStreams.foreach(_.initialize(zeroTime)) @@ -68,20 +66,16 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def setBatchDuration(duration: Duration) { this.synchronized { - if (batchDuration != null) { - throw new Exception("Batch duration already set as " + batchDuration + - ". 
cannot set it again.") - } + require(batchDuration == null, + s"Batch duration already set as $batchDuration. Cannot set it again.") batchDuration = duration } } def remember(duration: Duration) { this.synchronized { - if (rememberDuration != null) { - throw new Exception("Remember duration already set as " + batchDuration + - ". cannot set it again.") - } + require(rememberDuration == null, + s"Remember duration already set as $rememberDuration. Cannot set it again.") rememberDuration = duration } } @@ -117,7 +111,11 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def generateJobs(time: Time): Seq[Job] = { logDebug("Generating jobs for time " + time) val jobs = this.synchronized { - outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + outputStreams.flatMap { outputStream => + val jobOption = outputStream.generateJob(time) + jobOption.foreach(_.setCallSite(outputStream.creationSite)) + jobOption + } } logDebug("Generated " + jobs.length + " jobs for time " + time) jobs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0b..051f53de64cd5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -200,6 +200,8 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + private[streaming] def getStartSite(): CallSite = startSite.get() + private var shutdownHookRef: AnyRef = _ conf.getOption("spark.streaming.checkpoint.directory").foreach(checkpoint) @@ -562,6 +564,13 @@ class StreamingContext private[streaming] ( ) } } + + if (Utils.isDynamicAllocationEnabled(sc.conf)) { + logWarning("Dynamic Allocation is enabled for this application. " + + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + + "Write Ahead Log is not enabled for non-replayable sources like Flume. " + + "See the programming guide for details on how to enable the Write Ahead Log") + } } /** @@ -588,12 +597,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. 
+ ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +635,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. @@ -676,32 +694,39 @@ class StreamingContext private[streaming] ( * @param stopGracefully if true, stops gracefully by waiting for the processing of all * received data to be completed */ - def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { - try { - state match { - case INITIALIZED => - logWarning("StreamingContext has not been started yet") - case STOPPED => - logWarning("StreamingContext has already been stopped") - case ACTIVE => - scheduler.stop(stopGracefully) - // Removing the streamingSource to de-register the metrics on stop() - env.metricsSystem.removeSource(streamingSource) - uiTab.foreach(_.detach()) - StreamingContext.setActiveContext(null) - waiter.notifyStop() - if (shutdownHookRef != null) { - ShutdownHookManager.removeShutdownHook(shutdownHookRef) - } - logInfo("StreamingContext stopped successfully") + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { + var shutdownHookRefToRemove: AnyRef = null + synchronized { + try { + state match { + case INITIALIZED => + logWarning("StreamingContext has not been started yet") + case STOPPED => + logWarning("StreamingContext has already been stopped") + case ACTIVE => + scheduler.stop(stopGracefully) + // Removing the streamingSource to de-register the metrics on stop() + env.metricsSystem.removeSource(streamingSource) + uiTab.foreach(_.detach()) + StreamingContext.setActiveContext(null) + waiter.notifyStop() + if (shutdownHookRef != null) { + shutdownHookRefToRemove = shutdownHookRef + shutdownHookRef = null + } + logInfo("StreamingContext stopped successfully") + } + } finally { + // The state should always be Stopped after calling `stop()`, even if we haven't started yet + state = STOPPED } - // Even if we have already stopped, we still need to attempt to stop the SparkContext because - // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). - if (stopSparkContext) sc.stop() - } finally { - // The state should always be Stopped after calling `stop()`, even if we haven't started yet - state = STOPPED } + if (shutdownHookRefToRemove != null) { + ShutdownHookManager.removeShutdownHook(shutdownHookRefToRemove) + } + // Even if we have already stopped, we still need to attempt to stop the SparkContext because + // a user might stop(stopSparkContext = false) and then call stop(stopSparkContext = true). + if (stopSparkContext) sc.stop() } private def stopOnShutdown(): Unit = { @@ -735,7 +760,7 @@ object StreamingContext extends Logging { throw new IllegalStateException( "Only one StreamingContext may be started in this JVM. 
" + "Currently running StreamingContext was started at" + - activeContext.get.startSite.get.longForm) + activeContext.get.getStartSite().longForm) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 2c373640d2fd9..dfc569451df86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -170,7 +170,7 @@ private[python] object PythonDStream { */ private[python] abstract class PythonDStream( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -187,7 +187,7 @@ private[python] abstract class PythonDStream( */ private[python] class PythonTransformedDStream ( parent: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends PythonDStream(parent, pfunc) { override def compute(validTime: Time): Option[RDD[Array[Byte]]] = { @@ -206,7 +206,7 @@ private[python] class PythonTransformedDStream ( private[python] class PythonTransformed2DStream( parent: DStream[_], parent2: DStream[_], - @transient pfunc: PythonTransformFunction) + pfunc: PythonTransformFunction) extends DStream[Array[Byte]] (parent.ssc) { val func = new TransformFunction(pfunc) @@ -230,7 +230,7 @@ private[python] class PythonTransformed2DStream( */ private[python] class PythonStateDStream( parent: DStream[Array[Byte]], - @transient reduceFunc: PythonTransformFunction) + reduceFunc: PythonTransformFunction) extends PythonDStream(parent, reduceFunc) { super.persist(StorageLevel.MEMORY_ONLY) @@ -252,8 +252,8 @@ private[python] class PythonStateDStream( */ private[python] class PythonReducedWindowedDStream( parent: DStream[Array[Byte]], - @transient preduceFunc: PythonTransformFunction, - @transient pinvReduceFunc: PythonTransformFunction, + preduceFunc: PythonTransformFunction, + @transient private val pinvReduceFunc: PythonTransformFunction, _windowDuration: Duration, _slideDuration: Duration) extends PythonDStream(parent, preduceFunc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index f396c347581ce..4eb92dd8b1053 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Time, StreamingContext} -import scala.reflect.ClassTag /** * An input stream that always returns the same RDD on each timestep. Useful for testing. 
@@ -27,6 +28,9 @@ import scala.reflect.ClassTag class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T]) extends InputDStream[T](ssc_) { + require(rdd != null, + "parameter rdd must not be null; a null RDD leads to NPEs in subsequent transformations") + override def start() {} override def stop() {} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index c358f5b5bd70b..40208a64861fb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -70,7 +70,7 @@ import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Uti */ private[streaming] class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, newFilesOnly: Boolean = true, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index a6c4cd220e42f..95994c983c0cc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param ssc_ Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) +abstract class InputDStream[T: ClassTag] (ssc_ : StreamingContext) extends DStream[T](ssc_) { private[streaming] var lastValidTime: Time = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 186e1bf03a944..002aac9f43617 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class PluggableInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, receiver: Receiver[T]) extends ReceiverInputDStream[T](ssc_) { def getReceiver(): Receiver[T] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index a2f5d82a79bd3..cd073646370d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.io.{NotSerializableException, ObjectOutputStream} +import java.io.{NotSerializableException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag @@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Time, StreamingContext} private[streaming] class QueueInputDStream[T: ClassTag]( - @transient ssc: StreamingContext, + ssc: StreamingContext, val queue: Queue[RDD[T]], oneAtATime: Boolean, defaultRDD: RDD[T] @@ -37,8 +37,13 @@ class QueueInputDStream[T: ClassTag]( override def stop()
{ } + private def readObject(in: ObjectInputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing. " + + "Please don't use queueStream when checkpointing is enabled.") + } + private def writeObject(oos: ObjectOutputStream): Unit = { - throw new NotSerializableException("queueStream doesn't support checkpointing") + logWarning("queueStream doesn't support checkpointing") } override def compute(validTime: Time): Option[RDD[T]] = { @@ -52,12 +57,12 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime) { Some(buffer.head) } else { - Some(new UnionRDD(ssc.sc, buffer.toSeq)) + Some(new UnionRDD(context.sc, buffer.toSeq)) } } else if (defaultRDD != null) { Some(defaultRDD) } else { - None + Some(ssc.sparkContext.emptyRDD) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index e2925b9e03ec3..5a9eda7c12776 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.receiver.Receiver */ private[streaming] class RawInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, storageLevel: StorageLevel diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 6c139f32da31d..87c20afd5c13c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -38,7 +38,7 @@ import org.apache.spark.streaming.{StreamingContext, Time} * @param ssc_ Streaming context that will execute this input stream * @tparam T Class type of the object of this stream */ -abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) +abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) extends InputDStream[T](ssc_) { /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 5ce5b7aae6e69..de84e0c9a498d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming.receiver.Receiver private[streaming] class SocketInputDStream[T: ClassTag]( - @transient ssc_ : StreamingContext, + ssc_ : StreamingContext, host: String, port: Int, bytesToObjects: InputStream => Iterator[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 5d46ca0715ffd..5eabdf63dc8d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -17,10 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.{PairRDDFunctions, RDD} -import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag +import org.apache.spark.SparkException 
+import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class TransformedDStream[U: ClassTag] ( parents: Seq[DStream[_]], @@ -37,7 +39,16 @@ class TransformedDStream[U: ClassTag] ( override def slideDuration: Duration = parents.head.slideDuration override def compute(validTime: Time): Option[RDD[U]] = { - val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq - Some(transformFunc(parentRDDs, validTime)) + val parentRDDs = parents.map { parent => parent.getOrCompute(validTime).getOrElse( + // Guard against parent DStreams that return None instead of Some(rdd), to avoid NPE + throw new SparkException(s"Couldn't generate RDD from parent at time $validTime")) + } + val transformedRDD = transformFunc(parentRDDs, validTime) + if (transformedRDD == null) { + throw new SparkException("Transform function must not return null. " + + "Return SparkContext.emptyRDD() instead to represent no element " + + "as the result of transformation.") + } + Some(transformedRDD) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 9405dbaa12329..d73ffdfd84d2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.SparkException import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) extends DStream[T](parents.head.ssc) { @@ -41,8 +42,8 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) val rdds = new ArrayBuffer[RDD[T]]() parents.map(_.getOrCompute(validTime)).foreach { case Some(rdd) => rdds += rdd - case None => throw new Exception("Could not generate RDD from a parent for unifying at time " - + validTime) + case None => throw new SparkException("Could not generate RDD from a parent for unifying at" + + s" time $validTime") } if (rdds.size > 0) { Some(new UnionRDD(ssc.sc, rdds)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index e081ffe46f502..f811784b25c82 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -61,7 +61,7 @@ class WriteAheadLogBackedBlockRDDPartition( * * * @param sc SparkContext - @param blockIds Ids of the blocks that contains this RDD's data + * @param _blockIds Ids of the blocks that contain this RDD's data * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data * @param isBlockIdValid Whether the block Ids are valid (i.e., the blocks are present in the Spark * executors). If not, then block lookups by the block ids will be skipped.
@@ -73,23 +73,23 @@ class WriteAheadLogBackedBlockRDDPartition( */ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, - @transient blockIds: Array[BlockId], + sc: SparkContext, + @transient private val _blockIds: Array[BlockId], @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) - extends BlockRDD[T](sc, blockIds) { + extends BlockRDD[T](sc, _blockIds) { require( - blockIds.length == walRecordHandles.length, - s"Number of block Ids (${blockIds.length}) must be " + + _blockIds.length == walRecordHandles.length, + s"Number of block Ids (${_blockIds.length}) must be " + s" same as number of WAL record handles (${walRecordHandles.length})") require( - isBlockIdValid.isEmpty || isBlockIdValid.length == blockIds.length, + isBlockIdValid.isEmpty || isBlockIdValid.length == _blockIds.length, s"Number of elements in isBlockIdValid (${isBlockIdValid.length}) must be " + - s" same as number of block Ids (${blockIds.length})") + s" same as number of block Ids (${_blockIds.length})") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -99,9 +99,9 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), isValid, walRecordHandles(i)) + new WriteAheadLogBackedBlockRDDPartition(i, _blockIds(i), isValid, walRecordHandles(i)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 9922b6bc1201b..436eb0a566141 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.streaming.Time * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing * @param processingEndTime Clock time of when the last job of this batch finished processing + * @param outputOperationInfos The output operations in this batch */ @DeveloperApi case class BatchInfo( @@ -36,7 +37,8 @@ case class BatchInfo( streamIdToInputInfo: Map[Int, StreamInputInfo], submissionTime: Long, processingStartTime: Option[Long], - processingEndTime: Option[Long] + processingEndTime: Option[Long], + outputOperationInfos: Map[Int, OutputOperationInfo] ) { @deprecated("Use streamIdToInputInfo instead", "1.5.0") @@ -67,4 +69,5 @@ case class BatchInfo( * The number of recorders received by the receivers in this batch. 
*/ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 3c481bf3491f9..ab1b3565fcc19 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming.scheduler +import scala.util.{Failure, Try} + import org.apache.spark.streaming.Time -import scala.util.Try +import org.apache.spark.util.{Utils, CallSite} /** * Class representing a Spark computation. It may contain multiple Spark jobs. @@ -29,6 +31,9 @@ class Job(val time: Time, func: () => _) { private var _outputOpId: Int = _ private var isSet = false private var _result: Try[_] = null + private var _callSite: CallSite = null + private var _startTime: Option[Long] = None + private var _endTime: Option[Long] = None def run() { _result = Try(func()) @@ -70,5 +75,29 @@ class Job(val time: Time, func: () => _) { _outputOpId = outputOpId } + def setCallSite(callSite: CallSite): Unit = { + _callSite = callSite + } + + def callSite: CallSite = _callSite + + def setStartTime(startTime: Long): Unit = { + _startTime = Some(startTime) + } + + def setEndTime(endTime: Long): Unit = { + _endTime = Some(endTime) + } + + def toOutputOperationInfo: OutputOperationInfo = { + val failureReason = if (_result != null && _result.isFailure) { + Some(Utils.exceptionString(_result.asInstanceOf[Failure[_]].exception)) + } else { + None + } + OutputOperationInfo( + time, outputOpId, callSite.shortForm, callSite.longForm, _startTime, _endTime, failureReason) + } + override def toString: String = id } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 0cd39594ee923..2480b4ec093e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -20,17 +20,18 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ -import scala.util.{Failure, Success} +import scala.util.Failure import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ -import org.apache.spark.util.{EventLoop, ThreadUtils} +import org.apache.spark.streaming.ui.UIUtils +import org.apache.spark.util.{EventLoop, ThreadUtils, Utils} private[scheduler] sealed trait JobSchedulerEvent -private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent -private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent +private[scheduler] case class JobStarted(job: Job, startTime: Long) extends JobSchedulerEvent +private[scheduler] case class JobCompleted(job: Job, completedTime: Long) extends JobSchedulerEvent private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent /** @@ -142,8 +143,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def processEvent(event: JobSchedulerEvent) { try { event match { - case JobStarted(job) => handleJobStart(job) - case JobCompleted(job) => handleJobCompletion(job) + case JobStarted(job, startTime) => handleJobStart(job, startTime) + case JobCompleted(job, completedTime) 
=> handleJobCompletion(job, completedTime) case ErrorReported(m, e) => handleError(m, e) } } catch { @@ -152,7 +153,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } - private def handleJobStart(job: Job) { + private def handleJobStart(job: Job, startTime: Long) { val jobSet = jobSets.get(job.time) val isFirstJobOfJobSet = !jobSet.hasStarted jobSet.handleJobStart(job) @@ -161,26 +162,30 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // correct "jobSet.processingStartTime". listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } + job.setStartTime(startTime) + listenerBus.post(StreamingListenerOutputOperationStarted(job.toOutputOperationInfo)) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) } - private def handleJobCompletion(job: Job) { + private def handleJobCompletion(job: Job, completedTime: Long) { + val jobSet = jobSets.get(job.time) + jobSet.handleJobCompletion(job) + job.setEndTime(completedTime) + listenerBus.post(StreamingListenerOutputOperationCompleted(job.toOutputOperationInfo)) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + jobGenerator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) + } job.result match { - case Success(_) => - val jobSet = jobSets.get(job.time) - jobSet.handleJobCompletion(job) - logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) - if (jobSet.hasCompleted) { - jobSets.remove(jobSet.time) - jobGenerator.onBatchCompletion(jobSet.time) - logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( - jobSet.totalDelay / 1000.0, jobSet.time.toString, - jobSet.processingDelay / 1000.0 - )) - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) - } case Failure(e) => reportError("Error running job " + job, e) + case _ => } } @@ -190,16 +195,26 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } private class JobHandler(job: Job) extends Runnable with Logging { + import JobScheduler._ + def run() { - ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) - ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) try { + val formattedTime = UIUtils.formatBatchTime( + job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" + val batchLinkText = s"[output operation ${job.outputOpId}, batch time ${formattedTime}]" + + ssc.sc.setJobDescription( + s"""Streaming job from <a href="$batchUrl">$batchLinkText</a>""") + ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString) + ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString) + + // We need to assign `eventLoop` to a temp variable. Otherwise, because // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then // it's possible that when `post` is called, `eventLoop` happens to null.
var _eventLoop = eventLoop if (_eventLoop != null) { - _eventLoop.post(JobStarted(job)) + _eventLoop.post(JobStarted(job, clock.getTimeMillis())) // Disable checks for existing output directories in jobs launched by the streaming // scheduler, since we may need to write output to an existing directory during checkpoint // recovery; see SPARK-4835 for more details. @@ -208,7 +223,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } _eventLoop = eventLoop if (_eventLoop != null) { - _eventLoop.post(JobCompleted(job)) + _eventLoop.post(JobCompleted(job, clock.getTimeMillis())) } } else { // JobScheduler has been stopped. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 95833efc9417f..f76300351e3c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.HashSet +import scala.util.Failure import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils /** Class representing a set of Jobs * belong to the same batch. @@ -62,12 +64,13 @@ case class JobSet( } def toBatchInfo: BatchInfo = { - new BatchInfo( + BatchInfo( time, streamIdToInputInfo, submissionTime, - if (processingStartTime >= 0 ) Some(processingStartTime) else None, - if (processingEndTime >= 0 ) Some(processingEndTime) else None + if (processingStartTime >= 0) Some(processingStartTime) else None, + if (processingEndTime >= 0) Some(processingEndTime) else None, + jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala new file mode 100644 index 0000000000000..137e512a670da --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/OutputOperationInfo.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.streaming.Time + +/** + * :: DeveloperApi :: + * Class having information on output operations. + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. 
+ * @param startTime Clock time of when the output operation started processing + * @param endTime Clock time of when the output operation finished processing + * @param failureReason Failure reason if this output operation fails + */ +@DeveloperApi +case class OutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Option[Long], + endTime: Option[Long], + failureReason: Option[String]) { + + /** + * Return the duration of this output operation. + */ + def duration: Option[Long] = for (s <- startTime; e <- endTime) yield e - s +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 10b5a7f57a802..d2b0be7f4a9c5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -21,6 +21,7 @@ import scala.collection.Map import scala.collection.mutable import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * A class that tries to schedule receivers with evenly distributed. There are two phases for @@ -79,7 +80,7 @@ private[streaming] class ReceiverSchedulingPolicy { return receivers.map(_.streamId -> Seq.empty).toMap } - val hostToExecutors = executors.groupBy(_.split(":")(0)) + val hostToExecutors = executors.groupBy(executor => Utils.parseHostPort(executor)._1) val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) val numReceiversOnExecutor = mutable.HashMap[String, Int]() // Set the initial value to 0 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f86fd44b48719..2ce80d618b0a3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -30,7 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} +import org.apache.spark.util.{Utils, ThreadUtils, SerializableConfiguration} /** Enumeration to identify current state of a Receiver */ @@ -474,7 +474,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Remote messages case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) @@ -551,9 +551,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (scheduledExecutors.isEmpty) { ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) + val preferredLocations = + scheduledExecutors.map(hostPort => Utils.parseHostPort(hostPort)._1).distinct + ssc.sc.makeRDD(Seq(receiver -> preferredLocations)) } receiverRDD.setName(s"Receiver $receiverId") + ssc.sparkContext.setJobDescription(s"Streaming job running receiver $receiverId") +
ssc.sparkContext.setCallSite(Option(ssc.getStartSite()).getOrElse(Utils.getCallSite())) + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) // We will keep restarting the receiver job until ReceiverTracker is stopped diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 74dbba453f026..d19bdbb443c5e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -38,6 +38,14 @@ case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends Streami @DeveloperApi case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +@DeveloperApi +case class StreamingListenerOutputOperationStarted(outputOperationInfo: OutputOperationInfo) + extends StreamingListenerEvent + +@DeveloperApi +case class StreamingListenerOutputOperationCompleted(outputOperationInfo: OutputOperationInfo) + extends StreamingListenerEvent + @DeveloperApi case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo) extends StreamingListenerEvent @@ -75,6 +83,14 @@ trait StreamingListener { /** Called when processing of a batch of jobs has completed. */ def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted) { } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index b07d6cf347ca7..ca111bb636ed5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -43,6 +43,10 @@ private[spark] class StreamingListenerBus listener.onBatchStarted(batchStarted) case batchCompleted: StreamingListenerBatchCompleted => listener.onBatchCompleted(batchCompleted) + case outputOperationStarted: StreamingListenerOutputOperationStarted => + listener.onOutputOperationStarted(outputOperationStarted) + case outputOperationCompleted: StreamingListenerOutputOperationCompleted => + listener.onOutputOperationCompleted(outputOperationCompleted) case _ => } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index f702bd5bc9466..125cafd41b8af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -17,9 +17,6 @@ package org.apache.spark.streaming.ui -import java.text.SimpleDateFormat -import java.util.Date - import scala.xml.Node import org.apache.spark.ui.{UIUtils => SparkUIUtils} @@ -46,7 +43,8 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") val batchTimeId = 
s"batch-$batchTime" - + {formattedBatchTime} @@ -75,6 +73,19 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) batchTable } + protected def createOutputOperationProgressBar(batch: BatchUIData): Seq[Node] = { + + { + SparkUIUtils.makeProgressBar( + started = batch.numActiveOutputOp, + completed = batch.numCompletedOutputOp, + failed = batch.numFailedOutputOp, + skipped = 0, + total = batch.outputOperations.size) + } + + } + /** * Return HTML for all rows of this table. */ @@ -86,7 +97,10 @@ private[ui] class ActiveBatchTable( waitingBatches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { - override protected def columns: Seq[Node] = super.columns ++ Status + override protected def columns: Seq[Node] = super.columns ++ { + Output Ops: Succeeded/Total + Status + } override protected def renderRows: Seq[Node] = { // The "batchTime"s of "waitingBatches" must be greater than "runningBatches"'s, so display @@ -96,20 +110,21 @@ private[ui] class ActiveBatchTable( } private def runningBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ processing + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ processing } private def waitingBatchRow(batch: BatchUIData): Seq[Node] = { - baseRow(batch) ++ queued + baseRow(batch) ++ createOutputOperationProgressBar(batch) ++ queued } } private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) extends BatchTableBase("completed-batches-table", batchInterval) { - override protected def columns: Seq[Node] = super.columns ++ - Total Delay - {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} + override protected def columns: Seq[Node] = super.columns ++ { + Total Delay {SparkUIUtils.tooltip("Total time taken to handle a batch", "top")} + Output Ops: Succeeded/Total + } override protected def renderRows: Seq[Node] = { batches.flatMap(batch => {completedBatchRow(batch)}) @@ -118,9 +133,11 @@ private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: private def completedBatchRow(batch: BatchUIData): Seq[Node] = { val totalDelay = batch.totalDelay val formattedTotalDelay = totalDelay.map(SparkUIUtils.formatDuration).getOrElse("-") - baseRow(batch) ++ + + baseRow(batch) ++ { {formattedTotalDelay} + } ++ createOutputOperationProgressBar(batch) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 90d1b0fadecfc..2ed925572826e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -19,14 +19,14 @@ package org.apache.spark.streaming.ui import javax.servlet.http.HttpServletRequest -import scala.xml.{NodeSeq, Node, Text, Unparsed} +import scala.xml._ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.streaming.Time -import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} -import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} +import org.apache.spark.streaming.ui.StreamingJobProgressListener.{OutputOpId, SparkJobId} import org.apache.spark.ui.jobs.UIData.JobUIData +import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) @@ -38,6 +38,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { 
Output Op Id Description Duration + Status Job Id Duration Stages: Succeeded/Total @@ -46,27 +47,49 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, isFirstRow: Boolean, sparkJob: SparkJobIdWithUIData): Seq[Node] = { if (sparkJob.jobUIData.isDefined) { - generateNormalJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, sparkJob.jobUIData.get) } else { - generateDroppedJobRow(outputOpId, outputOpDescription, formattedOutputOpDuration, + generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, sparkJob.sparkJobId) } } + private def generateOutputOpRowWithoutSparkJobs( + outputOpData: OutputOperationUIData, + outputOpDescription: Seq[Node], + formattedOutputOpDuration: String): Seq[Node] = { + + {outputOpData.id.toString} + {outputOpDescription} + {formattedOutputOpDuration} + {outputOpStatusCell(outputOpData, rowspan = 1)} + + - + + - + + - + + - + + - + + } + /** * Generate a row for a Spark Job. Because duplicated output op infos needs to be collapsed into * one cell, we use "rowspan" for the first row of a output op. */ private def generateNormalJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, @@ -90,11 +113,12 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} - {formattedOutputOpDuration} + {formattedOutputOpDuration} ++ + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -125,7 +149,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { total = sparkJob.numTasks - sparkJob.numSkippedTasks) } - {failureReasonCell(lastFailureReason)} + {failureReasonCell(lastFailureReason, rowspan = 1)} } @@ -134,7 +158,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * with "-" cells. */ private def generateDroppedJobRow( - outputOpId: OutputOpId, + outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, numSparkJobRowsInOutputOp: Int, @@ -145,9 +169,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { // scalastyle:off val prefixCells = if (isFirstRow) { - {outputOpId.toString} + {outputOpData.id.toString} {outputOpDescription} - {formattedOutputOpDuration} + {formattedOutputOpDuration} ++ + {outputOpStatusCell(outputOpData, numSparkJobRowsInOutputOp)} } else { Nil } @@ -156,7 +181,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { {prefixCells} - {jobId.toString} + {if (jobId >= 0) jobId.toString else "-"} - @@ -170,56 +195,60 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( - outputOpId: OutputOpId, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - // We don't count the durations of dropped jobs - val sparkJobDurations = sparkJobs.filter(_.jobUIData.nonEmpty).map(_.jobUIData.get). 
- map(sparkJob => { - sparkJob.submissionTime.map { start => - val end = sparkJob.completionTime.getOrElse(System.currentTimeMillis()) - end - start - } - }) + outputOpData: OutputOperationUIData, + sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = - if (sparkJobDurations.isEmpty || sparkJobDurations.exists(_ == None)) { - // If no job or any job does not finish, set "formattedOutputOpDuration" to "-" + if (outputOpData.duration.isEmpty) { "-" } else { - SparkUIUtils.formatDuration(sparkJobDurations.flatMap(x => x).sum) + SparkUIUtils.formatDuration(outputOpData.duration.get) } - val description = generateOutputOpDescription(sparkJobs) + val description = generateOutputOpDescription(outputOpData) - generateJobRow( - outputOpId, description, formattedOutputOpDuration, sparkJobs.size, true, sparkJobs.head) ++ - sparkJobs.tail.map { sparkJob => + if (sparkJobs.isEmpty) { + generateOutputOpRowWithoutSparkJobs(outputOpData, description, formattedOutputOpDuration) + } else { + val firstRow = generateJobRow( - outputOpId, description, formattedOutputOpDuration, sparkJobs.size, false, sparkJob) - }.flatMap(x => x) - } - - private def generateOutputOpDescription(sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { - val lastStageInfo = - sparkJobs.flatMap(_.jobUIData).headOption. // Get the first JobUIData - flatMap { sparkJob => // For the first job, get the latest Stage info - if (sparkJob.stageIds.isEmpty) { - None - } else { - sparkListener.stageIdToInfo.get(sparkJob.stageIds.max) - } + outputOpData, + description, + formattedOutputOpDuration, + sparkJobs.size, + true, + sparkJobs.head) + val tailRows = + sparkJobs.tail.map { sparkJob => + generateJobRow( + outputOpData, + description, + formattedOutputOpDuration, + sparkJobs.size, + false, + sparkJob) } - val lastStageData = lastStageInfo.flatMap { s => - sparkListener.stageIdToData.get((s.stageId, s.attemptId)) + (firstRow ++ tailRows).flatten } + } - val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") - val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") - - - {lastStageDescription} - ++ Text(lastStageName) + private def generateOutputOpDescription(outputOp: OutputOperationUIData): Seq[Node] = { +

+ {outputOp.name} + + +details + + +
} - private def failureReasonCell(failureReason: String): Seq[Node] = { + private def failureReasonCell( + failureReason: String, + rowspan: Int, + includeFirstLineInExpandDetails: Boolean = true): Seq[Node] = { val isMultiline = failureReason.indexOf('\n') >= 0 // Display the first line by default val failureReasonSummary = StringEscapeUtils.escapeHtml4( @@ -228,6 +257,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { failureReason }) + val failureDetails = + if (isMultiline && !includeFirstLineInExpandDetails) { + // Skip the first line + failureReason.substring(failureReason.indexOf('\n') + 1) + } else { + failureReason + } val details = if (isMultiline) { // scalastyle:off ++ // scalastyle:on } else { "" } - {failureReasonSummary}{details} + + if (rowspan == 1) { + {failureReasonSummary}{details} + } else { + + {failureReasonSummary}{details} + + } } private def getJobData(sparkJobId: SparkJobId): Option[JobUIData] = { @@ -252,20 +295,37 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } + private def generateOutputOperationStatusForUI(failure: String): String = { + if (failure.startsWith("org.apache.spark.SparkException")) { + "Failed due to Spark job error\n" + failure + } else { + var nextLineIndex = failure.indexOf("\n") + if (nextLineIndex < 0) { + nextLineIndex = failure.size + } + val firstLine = failure.substring(0, nextLineIndex) + s"Failed due to error: $firstLine\n$failure" + } + } + /** * Generate the job table for the batch. */ private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { - val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId).toSeq. - sortBy(_._1). // sorted by OutputOpId + val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). 
map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId (outputOpId, outputOpIdAndSparkJobIds.map(_.sparkJobId).sorted) } + + val outputOps: Seq[(OutputOperationUIData, Seq[SparkJobId])] = + batchUIData.outputOperations.map { case (outputOpId, outputOperation) => + val sparkJobIds = outputOpIdToSparkJobIds.getOrElse(outputOpId, Seq.empty) + (outputOperation, sparkJobIds) + }.toSeq.sortBy(_._1.id) sparkListener.synchronized { - val outputOpIdWithJobs: Seq[(OutputOpId, Seq[SparkJobIdWithUIData])] = - outputOpIdToSparkJobIds.map { case (outputOpId, sparkJobIds) => - (outputOpId, + val outputOpWithJobs = outputOps.map { case (outputOpData, sparkJobIds) => + (outputOpData, sparkJobIds.map(sparkJobId => SparkJobIdWithUIData(sparkJobId, getJobData(sparkJobId)))) } @@ -275,8 +335,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { - outputOpIdWithJobs.map { - case (outputOpId, sparkJobIds) => generateOutputOpIdRow(outputOpId, sparkJobIds) + outputOpWithJobs.map { case (outputOpData, sparkJobIds) => + generateOutputOpIdRow(outputOpData, sparkJobIds) } } @@ -284,7 +344,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } } - def render(request: HttpServletRequest): Seq[Node] = { + def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } @@ -337,20 +397,13 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { - val jobTable = - if (batchUIData.outputOpIdSparkJobIdPairs.isEmpty) { -
Cannot find any job for Batch {formattedBatchTime}.
- } else { - generateJobTable(batchUIData) - } - - val content = summary ++ jobTable + val content = summary ++ generateJobTable(batchUIData) SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { - +
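The generateJobTable rewrite in the hunks above hinges on one grouping step: every (outputOpId, sparkJobId) pair is bucketed by output operation, each bucket's job ids are sorted, and output operations that ran no Spark job at all still get a row (the generateOutputOpRowWithoutSparkJobs case). A minimal standalone sketch of that pairing, using simplified stand-in types rather than the real BatchUIData/SparkJobIdWithUIData classes:

```scala
object OutputOpGroupingSketch {
  // Illustrative stand-ins, not the Spark classes.
  case class OutputOpIdAndSparkJobId(outputOpId: Int, sparkJobId: Int)
  case class OutputOperation(id: Int, name: String)

  def pairOpsWithJobs(
      outputOperations: Map[Int, OutputOperation],
      pairs: Seq[OutputOpIdAndSparkJobId]): Seq[(OutputOperation, Seq[Int])] = {
    // Bucket the (outputOpId, sparkJobId) pairs by output op id and sort the
    // job ids so rows render deterministically.
    val opIdToJobIds: Map[Int, Seq[Int]] =
      pairs.groupBy(_.outputOpId).map { case (opId, ps) => opId -> ps.map(_.sparkJobId).sorted }
    // Every output operation gets an entry, even when it ran no Spark job.
    outputOperations.values.toSeq.sortBy(_.id).map { op =>
      (op, opIdToJobIds.getOrElse(op.id, Seq.empty))
    }
  }

  def main(args: Array[String]): Unit = {
    val ops = Map(0 -> OutputOperation(0, "print"), 1 -> OutputOperation(1, "saveAsTextFiles"))
    val pairs = Seq(OutputOpIdAndSparkJobId(0, 7), OutputOpIdAndSparkJobId(0, 5))
    pairOpsWithJobs(ops, pairs).foreach(println)
    // (OutputOperation(0,print),List(5, 7))
    // (OutputOperation(1,saveAsTextFiles),List())
  }
}
```

Output operation 1 above is paired with an empty job list, which mirrors the case the patch now renders with a single row of "-" cells instead of dropping the batch's job table entirely.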
@@ -377,4 +430,18 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { Unparsed(StringEscapeUtils.escapeHtml4(metadataDescription). replaceAllLiterally("\t", "&nbsp;&nbsp;&nbsp;&nbsp;").replaceAllLiterally("\n", "<br/>
")) } + + private def outputOpStatusCell(outputOp: OutputOperationUIData, rowspan: Int): Seq[Node] = { + outputOp.failureReason match { + case Some(failureReason) => + val failureReasonForUI = generateOutputOperationStatusForUI(failureReason) + failureReasonCell(failureReasonForUI, rowspan, includeFirstLineInExpandDetails = false) + case None => + if (outputOp.endTime.isEmpty) { + + } else { + + } + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index ae508c0e9577b..3ef3689de1c45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -18,8 +18,10 @@ package org.apache.spark.streaming.ui +import scala.collection.mutable + import org.apache.spark.streaming.Time -import org.apache.spark.streaming.scheduler.{BatchInfo, StreamInputInfo} +import org.apache.spark.streaming.scheduler.{BatchInfo, OutputOperationInfo, StreamInputInfo} import org.apache.spark.streaming.ui.StreamingJobProgressListener._ private[ui] case class OutputOpIdAndSparkJobId(outputOpId: OutputOpId, sparkJobId: SparkJobId) @@ -30,6 +32,7 @@ private[ui] case class BatchUIData( val submissionTime: Long, val processingStartTime: Option[Long], val processingEndTime: Option[Long], + val outputOperations: mutable.HashMap[OutputOpId, OutputOperationUIData] = mutable.HashMap(), var outputOpIdSparkJobIdPairs: Seq[OutputOpIdAndSparkJobId] = Seq.empty) { /** @@ -59,17 +62,75 @@ private[ui] case class BatchUIData( * The number of recorders received by the receivers in this batch. */ def numRecords: Long = streamIdToInputInfo.values.map(_.numRecords).sum + + /** + * Update an output operation information of this batch. + */ + def updateOutputOperationInfo(outputOperationInfo: OutputOperationInfo): Unit = { + assert(batchTime == outputOperationInfo.batchTime) + outputOperations(outputOperationInfo.id) = OutputOperationUIData(outputOperationInfo) + } + + /** + * Return the number of failed output operations. + */ + def numFailedOutputOp: Int = outputOperations.values.count(_.failureReason.nonEmpty) + + /** + * Return the number of running output operations. + */ + def numActiveOutputOp: Int = outputOperations.values.count(_.endTime.isEmpty) + + /** + * Return the number of completed output operations. 
+ */ + def numCompletedOutputOp: Int = outputOperations.values.count { + op => op.failureReason.isEmpty && op.endTime.nonEmpty + } + + /** + * Return if this batch has any output operations + */ + def isFailed: Boolean = numFailedOutputOp != 0 } private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { + val outputOperations = mutable.HashMap[OutputOpId, OutputOperationUIData]() + outputOperations ++= batchInfo.outputOperationInfos.mapValues(OutputOperationUIData.apply) new BatchUIData( batchInfo.batchTime, batchInfo.streamIdToInputInfo, batchInfo.submissionTime, batchInfo.processingStartTime, - batchInfo.processingEndTime + batchInfo.processingEndTime, + outputOperations + ) + } +} + +private[ui] case class OutputOperationUIData( + id: OutputOpId, + name: String, + description: String, + startTime: Option[Long], + endTime: Option[Long], + failureReason: Option[String]) { + + def duration: Option[Long] = for (s <- startTime; e <- endTime) yield e - s +} + +private[ui] object OutputOperationUIData { + + def apply(outputOperationInfo: OutputOperationInfo): OutputOperationUIData = { + OutputOperationUIData( + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description, + outputOperationInfo.startTime, + outputOperationInfo.endTime, + outputOperationInfo.failureReason ) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 78aeb004e18b1..f6cc6edf2569a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -119,6 +119,20 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = synchronized { + // This method is called after onBatchStarted + runningBatchUIData(outputOperationStarted.outputOperationInfo.batchTime). + updateOutputOperationInfo(outputOperationStarted.outputOperationInfo) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = synchronized { + // This method is called before onBatchCompleted + runningBatchUIData(outputOperationCompleted.outputOperationInfo.batchTime). + updateOutputOperationInfo(outputOperationCompleted.outputOperationInfo) + } + override def onJobStart(jobStart: SparkListenerJobStart): Unit = synchronized { getBatchTimeAndOutputOpId(jobStart.properties).foreach { case (batchTime, outputOpId) => var outputOpIdToSparkJobIds = batchTimeToOutputOpIdSparkJobIdPair.get(batchTime) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index dd32ad5ad811d..0148cb51c6f09 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -72,8 +72,10 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: /** * Stop the timer, and return the last time the callback was made. - * interruptTimer = true will interrupt the callback + * - interruptTimer = true will interrupt the callback * if it is in progress (not guaranteed to give correct time in this case). 
+ * - interruptTimer = false guarantees that there will be at least one callback after `stop` has + * been called. */ def stop(interruptTimer: Boolean): Long = synchronized { if (!stopped) { @@ -87,18 +89,23 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: prevTime } + private def triggerActionForNextInterval(): Unit = { + clock.waitTillTime(nextTime) + callback(nextTime) + prevTime = nextTime + nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) + } + /** * Repeatedly call the callback every interval. */ private def loop() { try { while (!stopped) { - clock.waitTillTime(nextTime) - callback(nextTime) - prevTime = nextTime - nextTime += period - logDebug("Callback for " + name + " called at time " + prevTime) + triggerActionForNextInterval() } + triggerActionForNextInterval() } catch { case e: InterruptedException => } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13f..c5217149224e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. 
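The RecurringTimer hunk above tightens the stop contract: the loop now delegates each tick to triggerActionForNextInterval() and fires it once more after `stopped` flips, so a non-interrupting stop always sees one final callback. A rough sketch of that structure with plain wall-clock sleeps (an illustration only, not the Spark class, which uses a pluggable Clock):

```scala
// Not Spark's RecurringTimer: a stripped-down illustration of the loop shape.
class SimpleRecurringTimer(periodMs: Long, callback: Long => Unit) {
  @volatile private var stopped = false
  private var nextTime = 0L

  private val thread = new Thread(() => {
    try {
      nextTime = System.currentTimeMillis() + periodMs
      while (!stopped) {
        triggerActionForNextInterval()
      }
      // One extra tick after `stopped` flips: this is what guarantees at least
      // one callback after a non-interrupting stop().
      triggerActionForNextInterval()
    } catch {
      case _: InterruptedException => // stop(interruptTimer = true) cuts the wait short
    }
  })

  private def triggerActionForNextInterval(): Unit = {
    // Wait until the next scheduled time, then invoke the callback.
    while (System.currentTimeMillis() < nextTime) {
      Thread.sleep(1)
    }
    callback(nextTime)
    nextTime += periodMs
  }

  def start(): Unit = thread.start()

  def stop(interruptTimer: Boolean): Unit = {
    stopped = true
    if (interruptTimer) {
      thread.interrupt()
    }
    thread.join()
  }
}
```

In this sketch, stop(interruptTimer = false) blocks until that extra callback has run, while stop(true) interrupts any in-progress wait instead, matching the two bullet points added to the scaladoc.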
public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + 
out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +376,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) 
throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> 
res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +686,13 @@ public void testStreamingContextTransform(){ ); 
List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -782,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new 
ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = 
JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", 
Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1216,16 @@ public void testGroupByKeyAndWindow() { } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new 
Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1298,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - 
Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ -1444,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new 
Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1558,29 @@ public Iterable 
call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new 
Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. 
@Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69dec..ec2bffd6a5b97 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; 
} }); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 255376807c957..9d296c6d3ef8b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -191,6 +191,20 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("union with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 8, 101 to 108, 201 to 208) + intercept[SparkException] { + testOperation( + input, + (s: DStream[Int]) => s.union(s.map(_ + 4)), + output, + input.length, + false + ) + } + } + test("StreamingContext.union") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -211,6 +225,32 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("transform with NULL") { + val input = Seq(1 to 4) + intercept[SparkException] { + testOperation( + input, + (r: DStream[Int]) => r.transform(rdd => null.asInstanceOf[RDD[Int]]), + Seq(Seq()), + 1, + false + ) + } + } + + test("transform with input stream return None") { + val input = Seq(1 to 4, 5 to 8, null) + intercept[SparkException] { + testOperation( + input, + (r: 
DStream[Int]) => r.transform(rdd => rdd.map(_.toString)), + input.filterNot(_ == null).map(_.map(_.toString)), + input.length, + false + ) + } + } + test("transformWith") { val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) @@ -231,6 +271,27 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData1, inputData2, operation, outputData, true) } + test("transformWith with input stream return None") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), null ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), null ) + val outputData = Seq( + Seq("a", "b", "a", "b"), + Seq("a", "b", "", ""), + Seq("") + ) + + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.transformWith( // RDD.join in transform + s2, + (rdd1: RDD[String], rdd2: RDD[String]) => rdd1.union(rdd2) + ) + } + + intercept[SparkException] { + testOperation(inputData1, inputData2, operation, outputData, inputData1.length, true) + } + } + test("StreamingContext.transform") { val input = Seq(1 to 4, 101 to 104, 201 to 204) val output = Seq(1 to 12, 101 to 112, 201 to 212) @@ -247,6 +308,24 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(input, operation, output) } + test("StreamingContext.transform with input stream return None") { + val input = Seq(1 to 4, 101 to 104, 201 to 204, null) + val output = Seq(1 to 12, 101 to 112, 201 to 212) + + // transform over 3 DStreams by doing union of the 3 RDDs + val operation = (s: DStream[Int]) => { + s.context.transform( + Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams + (rdds: Seq[RDD[_]], time: Time) => + rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs + ) + } + + intercept[SparkException] { + testOperation(input, operation, output, input.length, false) + } + } + test("cogroup") { val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() ) val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() ) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 1bba7a143edf2..a6956533c07a5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -408,10 +408,14 @@ class CheckpointSuite extends TestSuiteBase { ssc = new StreamingContext(checkpointDir) ssc.start() - val outputNew = advanceTimeWithRealDelay(ssc, 2) eventually(timeout(10.seconds)) { assert(RateTestReceiver.getActive().nonEmpty) + } + + advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(10.seconds)) { assert(RateTestReceiver.getActive().get.getDefaultBlockGeneratorRateLimit() === 200) } ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index ec2852d9a0206..047e38ef90998 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -76,6 +76,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { fail("Timeout: cannot finish all batches in 30 seconds") } + // Ensure progress listener has been notified of all events + ssc.scheduler.listenerBus.waitUntilEmpty(500) + // Verify all "InputInfo"s have been reported 
assert(ssc.progressListener.numTotalReceivedRecords === input.size) assert(ssc.progressListener.numTotalProcessedRecords === input.size) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c20..b2b6848719639 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -29,7 +29,8 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -47,7 +48,9 @@ class ReceivedBlockHandlerSuite with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + val conf = new SparkConf() + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") + .set("spark.app.id", "streaming-test") val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) @@ -184,7 +187,7 @@ class ReceivedBlockHandlerSuite } test("Test Block - isFullyConsumed") { - val sparkConf = new SparkConf() + val sparkConf = new SparkConf().set("spark.app.id", "streaming-test") sparkConf.set("spark.storage.unrollMemoryThreshold", "512") // spark.storage.unrollFraction set to 0.4 for BlockManager sparkConf.set("spark.storage.unrollFraction", "0.4") @@ -251,12 +254,14 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr, 0) - manager.initialize("app-id") - blockManagerBuffer += manager - manager + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager.setMemoryStore(blockManager.memoryStore) + blockManager.initialize("app-id") + blockManagerBuffer += blockManager + blockManager } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7423ef6bcb6ea..c7a877142b374 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source import org.apache.spark.storage.StorageLevel @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite 
with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound.contains("Streaming job from")) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) @@ -726,16 +758,26 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } test("queueStream doesn't support checkpointing") { - val checkpointDir = Utils.createTempDir() - ssc = new StreamingContext(master, appName, batchDuration) - val rdd = ssc.sparkContext.parallelize(1 to 10) - ssc.queueStream[Int](Queue(rdd)).print() - ssc.checkpoint(checkpointDir.getAbsolutePath) - val e = intercept[NotSerializableException] { - ssc.start() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + def creatingFunction(): StreamingContext = { + val _ssc = new StreamingContext(conf, batchDuration) + val rdd = _ssc.sparkContext.parallelize(1 to 10) + _ssc.checkpoint(checkpointDirectory) + _ssc.queueStream[Int](Queue(rdd)).register() + _ssc + } + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + val e = intercept[SparkException] { + ssc = StreamingContext.getOrCreate(checkpointDirectory, creatingFunction _) } // StreamingContext.validate changes the message, so use "contains" here - assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + assert(e.getCause.getMessage.contains("queueStream doesn't support checkpointing. 
" + + "Please don't use queueStream when checkpointing is enabled.")) } def addInputStream(s: StreamingContext): DStream[Int] = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index d840c349bbbc4..5dc0472c7770c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global @@ -140,6 +140,90 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + test("output operation reporting") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count()) + inputStream.foreachRDD(_.collect()) + inputStream.foreachRDD(_.count()) + + val collector = new OutputOperationInfoCollector + ssc.addStreamingListener(collector) + + ssc.start() + try { + eventually(timeout(30 seconds), interval(20 millis)) { + collector.startedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + collector.completedOutputOperationIds.take(3) should be (Seq(0, 1, 2)) + } + } finally { + ssc.stop() + } + } + + test("onBatchCompleted with successful batch") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + val failureReasons = startStreamingContextAndCollectFailureReasons(ssc) + assert(failureReasons != null && failureReasons.isEmpty, + "A successful batch should not set errorMessage") + } + + test("onBatchCompleted with failed batch and one failed job") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD { _ => + throw new RuntimeException("This is a failed job") + } + + // Check if failureReasons contains the correct error message + val failureReasons = startStreamingContextAndCollectFailureReasons(ssc, isFailed = true) + assert(failureReasons != null) + assert(failureReasons.size === 1) + assert(failureReasons.contains(0)) + assert(failureReasons(0).contains("This is a failed job")) + } + + test("onBatchCompleted with failed batch and multiple failed jobs") { + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD { _ => + throw new RuntimeException("This is a failed job") + } + inputStream.foreachRDD { _ => + throw new RuntimeException("This is another failed job") + } + + // Check if failureReasons contains the correct error messages + val failureReasons = + startStreamingContextAndCollectFailureReasons(ssc, isFailed = true) + assert(failureReasons != null) + assert(failureReasons.size === 2) + assert(failureReasons.contains(0)) + assert(failureReasons.contains(1)) + assert(failureReasons(0).contains("This is a failed job")) + assert(failureReasons(1).contains("This is another failed job")) + } + + private def startStreamingContextAndCollectFailureReasons( + _ssc: StreamingContext, isFailed: 
Boolean = false): Map[Int, String] = { + val failureReasonsCollector = new FailureReasonsCollector() + _ssc.addStreamingListener(failureReasonsCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + if (isFailed) { + intercept[RuntimeException] { + _ssc.awaitTerminationOrTimeout(10000) + } + } + _ssc.stop() + failureReasonsCollector.failureReasons.toMap + } + /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { for (i <- 1 until seq.size) { @@ -191,6 +275,22 @@ class ReceiverInfoCollector extends StreamingListener { } } +/** Listener that collects information on processed output operations */ +class OutputOperationInfoCollector extends StreamingListener { + val startedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + val completedOutputOperationIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + startedOutputOperationIds += outputOperationStarted.outputOperationInfo.id + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + completedOutputOperationIds += outputOperationCompleted.outputOperationInfo.id + } +} + class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_ONLY) with Logging { def onStart() { Future { @@ -205,3 +305,18 @@ class StreamingListenerSuiteReceiver extends Receiver[Any](StorageLevel.MEMORY_O } def onStop() { } } + +/** + * A StreamingListener that saves all latest `failureReasons` in a batch. 
+ */ +class FailureReasonsCollector extends StreamingListener { + + val failureReasons = new HashMap[Int, String] with SynchronizedMap[Int, String] + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + outputOperationCompleted.outputOperationInfo.failureReason.foreach { f => + failureReasons(outputOperationCompleted.outputOperationInfo.id) = f + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 068a6cb0e8fa4..a5744a9009c1c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -117,11 +117,11 @@ class UISeleniumSuite findAll(cssSelector("""#active-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", - "Status") + "Output Ops: Succeeded/Total", "Status") } findAll(cssSelector("""#completed-batches-table th""")).map(_.text).toSeq should be { List("Batch Time", "Input Size", "Scheduling Delay (?)", "Processing Time (?)", - "Total Delay (?)") + "Total Delay (?)", "Output Ops: Succeeded/Total") } val batchLinks = @@ -138,7 +138,7 @@ class UISeleniumSuite summaryText should contain ("Total delay:") findAll(cssSelector("""#batch-job-table th""")).map(_.text).toSeq should be { - List("Output Op Id", "Description", "Duration", "Job Id", "Duration", + List("Output Op Id", "Description", "Duration", "Status", "Job Id", "Duration", "Stages: Succeeded/Total", "Tasks (for all stages): Succeeded/Total", "Error") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index a38cc603f2190..2f11b255f1104 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -184,9 +184,10 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { // Verify that the final data is present in the final generated block and // pushed before complete stop assert(blockGenerator.isStopped() === false) // generator has not stopped yet - clock.advance(blockIntervalMs) // force block generation - failAfter(1 second) { - thread.join() + eventually(timeout(10 seconds), interval(10 milliseconds)) { + // Keep calling `advance` to avoid blocking forever in `clock.waitTillTime` + clock.advance(blockIntervalMs) + assert(thread.isAlive === false) } assert(blockGenerator.isStopped() === true) // generator has finally been completely stopped assert(listener.pushedData === data, "All data not pushed by stop()") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 45138b748ecab..fda86aef457d4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLocality} +import 
org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -80,6 +82,28 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("SPARK-11063: TaskSetManager should use Receiver RDD's preferredLocations") { + // Use ManualClock to prevent from starting batches so that we can make sure the only task is + // for starting the Receiver + val _conf = conf.clone.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") + withStreamingContext(new StreamingContext(_conf, Milliseconds(100))) { ssc => + @volatile var receiverTaskLocality: TaskLocality = null + ssc.sparkContext.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + receiverTaskLocality = taskStart.taskInfo.taskLocality + } + }) + val input = ssc.receiverStream(new TestReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + eventually(timeout(10 seconds), interval(10 millis)) { + // If preferredLocations is set correctly, receiverTaskLocality should be NODE_LOCAL + assert(receiverTaskLocality === TaskLocality.NODE_LOCAL) + } + } + } } /** An input DStream with for testing rate controlling */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 995f1197ccdfd..af4718b4eb705 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -63,7 +63,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -75,7 +75,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -116,7 +117,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -156,7 +158,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> 
StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) @@ -173,8 +176,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // fulfill completedBatchInfos for(i <- 0 until limit) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) val jobStart = createJobStart(Time(1000 + i * 100), outputOpId = 0, jobId = 1) listener.onJobStart(jobStart) @@ -185,7 +188,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart) val batchInfoSubmitted = - BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None) + BatchInfo(Time(1000 + limit * 100), Map.empty, (1000 + limit * 100), None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // We still can see the info retrieved from onJobStart @@ -201,8 +204,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { // A lot of "onBatchCompleted"s happen before "onJobStart" for(i <- limit + 1 to limit * 2) { - val batchInfoCompleted = - BatchInfo(Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None) + val batchInfoCompleted = BatchInfo( + Time(1000 + i * 100), Map.empty, 1000 + i * 100, Some(2000 + i * 100), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } @@ -227,11 +230,13 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val streamIdToInputInfo = Map(0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L)) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None) + val batchInfoSubmitted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoStarted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) // onJobStart @@ -248,7 +253,8 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.onJobStart(jobStart4) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None) + val batchInfoCompleted = + BatchInfo(Time(1000), streamIdToInputInfo, 1000, Some(2000), None, Map.empty) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala new file mode 100644 index 0000000000000..0544972d95c03 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RecurringTimerSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ManualClock + +class RecurringTimerSuite extends SparkFunSuite with PrivateMethodTester { + + test("basic") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-basic") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + clock.advance(100) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L)) + } + clock.advance(200) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L, 100L, 200L, 300L)) + } + assert(timer.stop(interruptTimer = true) === 300L) + } + + test("SPARK-10224: call 'callback' after stopping") { + val clock = new ManualClock() + val results = new mutable.ArrayBuffer[Long]() with mutable.SynchronizedBuffer[Long] + val timer = new RecurringTimer(clock, 100, time => { + results += time + }, "RecurringTimerSuite-SPARK-10224") + timer.start(0) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(results === Seq(0L)) + } + @volatile var lastTime = -1L + // Now RecurringTimer is waiting for the next interval + val thread = new Thread { + override def run(): Unit = { + lastTime = timer.stop(interruptTimer = false) + } + } + thread.start() + val stopped = PrivateMethod[RecurringTimer]('stopped) + // Make sure the `stopped` field has been changed + eventually(timeout(10.seconds), interval(10.millis)) { + assert(timer.invokePrivate(stopped()) === true) + } + clock.advance(200) + // When RecurringTimer is awake from clock.waitTillTime, it will call `callback` once. + // Then it will find `stopped` is true and exit the loop, but it should call `callback` again + // before exiting its internal thread. 
+ thread.join() + assert(results === Seq(0L, 100L, 200L)) + assert(lastTime === 200L) + } +} diff --git a/tags/pom.xml b/tags/pom.xml new file mode 100644 index 0000000000000..ca93722e73345 --- /dev/null +++ b/tags/pom.xml @@ -0,0 +1,50 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-test-tags_2.10 + jar + Spark Project Test Tags + http://spark.apache.org/ + + test-tags + + + + + org.scalatest + scalatest_${scala.binary.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java new file mode 100644 index 0000000000000..1b0c416b0fe4e --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java new file mode 100644 index 0000000000000..2a631bfc88cf0 --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/tools/pom.xml b/tools/pom.xml index 298ee2348b58e..1e64f280e5bed 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 89475ee3cf5a1..caf1f77890b58 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -56,14 +56,8 @@ - junit - junit - test - - - com.novocode - junit-interface - test + org.apache.spark + spark-test-tags_${scala.binary.version} org.mockito diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java index c08c9c73d2396..3ced2094f5e6b 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java @@ -19,7 +19,11 @@ import org.apache.spark.unsafe.Platform; -public class ByteArray { +import java.util.Arrays; + +public final class ByteArray { + + public static final byte[] EMPTY_BYTE = new byte[0]; /** * Writes the content of a byte array into a memory address, identified by an object and an @@ -29,4 +33,45 @@ public class ByteArray { public static void writeToMemory(byte[] src, Object target, long targetOffset) { Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET, target, targetOffset, src.length); } + + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. 
+ */ + public static long getPrefix(byte[] bytes) { + if (bytes == null) { + return 0L; + } else { + final int minLen = Math.min(bytes.length, 8); + long p = 0; + for (int i = 0; i < minLen; ++i) { + p |= (128L + Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i)) + << (56 - 8 * i); + } + return p; + } + } + + public static byte[] subStringSQL(byte[] bytes, int pos, int len) { + // This pos calculation is according to UTF8String#subStringSQL + if (pos > bytes.length) { + return EMPTY_BYTE; + } + int start = 0; + int end; + if (pos > 0) { + start = pos - 1; + } else if (pos < 0) { + start = bytes.length + pos; + } + if ((bytes.length - start) < len) { + end = bytes.length; + } else { + end = start + len; + } + start = Math.max(start, 0); // underflow + if (start >= end) { + return EMPTY_BYTE; + } + return Arrays.copyOfRange(bytes, start, end); + } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 216aeea60d1c8..b7aecb5102ba6 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -19,6 +19,7 @@ import javax.annotation.Nonnull; import java.io.*; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; import java.util.Map; @@ -137,6 +138,15 @@ public void writeToMemory(Object target, long targetOffset) { Platform.copyMemory(base, offset, target, targetOffset, numBytes); } + public void writeTo(ByteBuffer buffer) { + assert(buffer.hasArray()); + byte[] target = buffer.array(); + int offset = buffer.arrayOffset(); + int pos = buffer.position(); + writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos); + buffer.position(pos + numBytes); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point diff --git a/yarn/pom.xml b/yarn/pom.xml index f6737695307a2..3eadacba13e18 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml @@ -51,6 +51,10 @@ test-jar test + + org.apache.spark + spark-test-tags_${scala.binary.version} + org.apache.hadoop hadoop-yarn-api diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 991b5cec00bd8..4b4d9990ce9f9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -62,10 +62,21 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null - // Default to numExecutors * 2, with minimum of 3 - private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", - sparkConf.getInt("spark.yarn.max.worker.failures", - math.max(sparkConf.getInt("spark.executor.instances", 0) * 2, 3))) + // Default to twice the number of executors (twice the maximum number of executors if dynamic + // allocation is enabled), with a minimum of 3. 
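
The comment above states the new rule; the Scala block that follows implements it by choosing the config key based on whether dynamic allocation is enabled. A standalone sketch of just the arithmetic, with hypothetical names (the real code reads the counts from SparkConf and still lets an explicit spark.yarn.max.executor.failures override the default):

// Illustration only, not Spark API.
public final class MaxExecutorFailuresDefault {
  static int defaultMax(boolean dynamicAllocation, int numExecutors, int maxExecutors) {
    int effective = dynamicAllocation ? maxExecutors : numExecutors;
    return Math.max(3, 2 * effective);      // twice the executor count, never below 3
  }

  public static void main(String[] args) {
    System.out.println(defaultMax(false, 4, 0));   // 8  : twice spark.executor.instances
    System.out.println(defaultMax(true, 4, 50));   // 100: twice spark.dynamicAllocation.maxExecutors
    System.out.println(defaultMax(false, 0, 0));   // 3  : floor when nothing is configured
  }
}
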
+ + private val maxNumExecutorFailures = { + val defaultKey = + if (Utils.isDynamicAllocationEnabled(sparkConf)) { + "spark.dynamicAllocation.maxExecutors" + } else { + "spark.executor.instances" + } + val effectiveNumExecutors = sparkConf.getInt(defaultKey, 0) + val defaultMaxNumExecutorFailures = math.max(3, 2 * effectiveNumExecutors) + + sparkConf.getInt("spark.yarn.max.executor.failures", defaultMaxNumExecutorFailures) + } @volatile private var exitCode = 0 @volatile private var unregistered = false @@ -255,7 +266,6 @@ private[spark] class ApplicationMaster( driverRef, yarnConf, _sparkConf, - if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, securityMgr) @@ -345,7 +355,7 @@ private[spark] class ApplicationMaster( if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_MAX_EXECUTOR_FAILURES, - "Max number of executor failures reached") + s"Max number of executor failures ($maxNumExecutorFailures) reached") } else { logDebug("Sending progress") allocator.allocateResources() @@ -558,13 +568,15 @@ private[spark] class ApplicationMaster( override def onStart(): Unit = { driver.send(RegisterClusterManager(self)) - } override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver.send(x) + + case DriverHello => + // SPARK-10987: no action needed for this message. } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -590,6 +602,13 @@ private[spark] class ApplicationMaster( case None => logWarning("Container allocator is not ready to kill executors yet.") } context.reply(true) + + case GetExecutorLossReason(eid) => + Option(allocator) match { + case Some(a) => a.enqueueGetLossReasonRequest(eid, context) + case None => logWarning(s"Container allocator is not ready to find" + + s" executor loss reasons yet.") + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index b08412414aa1c..17d9943c795e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -105,9 +105,9 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) // scalastyle:on println System.exit(exitCode) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index e9a02baafd28e..4954b6180902e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,8 +54,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils private[spark] class Client( @@ -69,8 +70,6 @@ private[spark] class Client( def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) - def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) private var credentials: Credentials = null @@ -83,10 +82,31 @@ private[spark] class Client( private var principal: String = null private var keytab: String = null + private val launcherBackend = new LauncherBackend() { + override def onStopRequest(): Unit = { + if (isClusterMode && appId != null) { + yarnClient.killApplication(appId) + } else { + setState(SparkAppHandle.State.KILLED) + stop() + } + } + } private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) - def stop(): Unit = yarnClient.stop() + private var appId: ApplicationId = null + + def reportLauncherState(state: SparkAppHandle.State): Unit = { + launcherBackend.setState(state) + } + + def stop(): Unit = { + launcherBackend.close() + yarnClient.stop() + // Unset YARN mode system env variable, to allow switching between cluster types. + System.clearProperty("SPARK_YARN_MODE") + } /** * Submit an application running our ApplicationMaster to the ResourceManager. @@ -98,6 +118,7 @@ private[spark] class Client( def submitApplication(): ApplicationId = { var appId: ApplicationId = null try { + launcherBackend.connect() // Setup the credentials before doing anything else, // so we have don't have issues at any point. setupCredentials() @@ -111,6 +132,8 @@ private[spark] class Client( val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() appId = newAppResponse.getApplicationId() + reportLauncherState(SparkAppHandle.State.SUBMITTED) + launcherBackend.setAppId(appId.toString()) // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) @@ -185,6 +208,20 @@ private[spark] class Client( case None => logDebug("spark.yarn.maxAppAttempts is not set. 
" + "Cluster's default value will be used.") } + + if (sparkConf.contains("spark.yarn.am.attemptFailuresValidityInterval")) { + try { + val interval = sparkConf.getTimeAsMs("spark.yarn.am.attemptFailuresValidityInterval") + val method = appContext.getClass().getMethod( + "setAttemptFailuresValidityInterval", classOf[Long]) + method.invoke(appContext, interval: java.lang.Long) + } catch { + case e: NoSuchMethodException => + logWarning("Ignoring spark.yarn.am.attemptFailuresValidityInterval because the version " + + "of YARN does not support it") + } + } + val capability = Records.newRecord(classOf[Resource]) capability.setMemory(args.amMemory + amMemoryOverhead) capability.setVirtualCores(args.amCores) @@ -335,7 +372,8 @@ private[spark] class Client( destName: Option[String] = None, targetDir: Option[String] = None, appMasterOnly: Boolean = false): (Boolean, String) = { - val localURI = new URI(path.trim()) + val trimmedPath = path.trim() + val localURI = Utils.resolveURI(trimmedPath) if (localURI.getScheme != LOCAL_SCHEME) { if (addDistributedUri(localURI)) { val localPath = getQualifiedLocalPath(localURI, hadoopConf) @@ -351,7 +389,7 @@ private[spark] class Client( (false, null) } } else { - (true, path.trim()) + (true, trimmedPath) } } @@ -459,6 +497,19 @@ private[spark] class Client( */ private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() + + // Uploading $SPARK_CONF_DIR/log4j.properties file to the distributed cache to make sure that + // the executors will use the latest configurations instead of the default values. This is + // required when user changes log4j.properties directly to set the log configurations. If + // configuration file is provided through --files then executors will be taking configurations + // from --files instead of $SPARK_CONF_DIR/log4j.properties. 
+ val log4jFileName = "log4j.properties" + Option(Utils.getContextOrSparkClassLoader.getResource(log4jFileName)).foreach { url => + if (url.getProtocol == "file") { + hadoopConfFiles(log4jFileName) = new File(url.getPath) + } + } + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => val dir = new File(path) @@ -572,10 +623,10 @@ private[spark] class Client( LOCALIZED_PYTHON_DIR) } (pySparkArchives ++ pyArchives).foreach { path => - val uri = new URI(path) + val uri = Utils.resolveURI(path) if (uri.getScheme != LOCAL_SCHEME) { pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), - new Path(path).getName()) + new Path(uri).getName()) } else { pythonPath += uri.getPath() } @@ -726,6 +777,7 @@ private[spark] class Client( // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts) val userClass = if (isClusterMode) { @@ -875,6 +927,20 @@ private[spark] class Client( } } + if (lastState != state) { + state match { + case YarnApplicationState.RUNNING => + reportLauncherState(SparkAppHandle.State.RUNNING) + case YarnApplicationState.FINISHED => + reportLauncherState(SparkAppHandle.State.FINISHED) + case YarnApplicationState.FAILED => + reportLauncherState(SparkAppHandle.State.FAILED) + case YarnApplicationState.KILLED => + reportLauncherState(SparkAppHandle.State.KILLED) + case _ => + } + } + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { @@ -922,8 +988,8 @@ private[spark] class Client( * throw an appropriate SparkException. */ def run(): Unit = { - val appId = submitApplication() - if (fireAndForget) { + this.appId = submitApplication() + if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState logInfo(s"Application report for $appId (state: $state)") @@ -955,9 +1021,9 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), "pyspark.zip not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.9-src.zip") require(py4jFile.exists(), - "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + "py4j-0.9-src.zip not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) } } @@ -965,6 +1031,7 @@ private[spark] class Client( } object Client extends Logging { + def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { logWarning("WARNING: This client is deprecated and will be removed in a " + @@ -1045,7 +1112,10 @@ object Client extends Logging { s"in favor of the $CONF_SPARK_JAR configuration variable.") System.getenv(ENV_SPARK_JAR) } else { - SparkContext.jarOfClass(this.getClass).head + SparkContext.jarOfClass(this.getClass).getOrElse(throw new SparkException("Could not " + + "find jar containing Spark classes. The jar can be defined using the " + + "spark.yarn.jar configuration option. 
If testing Spark, either set that option or " + + "make sure SPARK_PREPEND_CLASSES is not set.")) } } @@ -1146,17 +1216,28 @@ object Client extends Logging { } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { - val userClassPath = + // in order to properly add the app jar when user classpath is first + // we have to do the mainJar separate in order to send the right thing + // into addFileToClasspath + val mainJar = + if (args != null) { + getMainJarUri(Option(args.userJar)) + } else { + getMainJarUri(sparkConf.getOption(CONF_SPARK_USER_JAR)) + } + mainJar.foreach(addFileToClasspath(sparkConf, conf, _, APP_JAR, env)) + + val secondaryJars = if (args != null) { - getUserClasspath(Option(args.userJar), Option(args.addJars)) + getSecondaryJarUris(Option(args.addJars)) } else { - getUserClasspath(sparkConf) + getSecondaryJarUris(sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) } - userClassPath.foreach { x => - addFileToClasspath(sparkConf, x, null, env) + secondaryJars.foreach { x => + addFileToClasspath(sparkConf, conf, x, null, env) } } - addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, conf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) @@ -1169,16 +1250,20 @@ object Client extends Logging { * @param conf Spark configuration. */ def getUserClasspath(conf: SparkConf): Array[URI] = { - getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR), - conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + val mainUri = getMainJarUri(conf.getOption(CONF_SPARK_USER_JAR)) + val secondaryUris = getSecondaryJarUris(conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS)) + (mainUri ++ secondaryUris).toArray } - private def getUserClasspath( - mainJar: Option[String], - secondaryJars: Option[String]): Array[URI] = { - val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_)) - val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) - (mainUri ++ secondaryUris).toArray + private def getMainJarUri(mainJar: Option[String]): Option[URI] = { + mainJar.flatMap { path => + val uri = Utils.resolveURI(path) + if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None + }.orElse(Some(new URI(APP_JAR))) + } + + private def getSecondaryJarUris(secondaryJars: Option[String]): Seq[URI] = { + secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_)) } /** @@ -1187,15 +1272,17 @@ object Client extends Logging { * If an alternate name for the file is given, and it's not a "local:" file, the alternate * name will be added to the classpath (relative to the job's work directory). * - * If not a "local:" file and no alternate name, the environment is not modified. + * If not a "local:" file and no alternate name, the linkName will be added to the classpath. * - * @param conf Spark configuration. - * @param uri URI to add to classpath (optional). - * @param fileName Alternate name for the file (optional). - * @param env Map holding the environment variables. + * @param conf Spark configuration. + * @param hadoopConf Hadoop configuration. + * @param uri URI to add to classpath (optional). + * @param fileName Alternate name for the file (optional). + * @param env Map holding the environment variables. 
*/ private def addFileToClasspath( conf: SparkConf, + hadoopConf: Configuration, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { @@ -1204,6 +1291,11 @@ object Client extends Logging { } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) + } else if (uri != null) { + val localPath = getQualifiedLocalPath(uri, hadoopConf) + val linkName = Option(uri.getFragment()).getOrElse(localPath.getName()) + addClasspathEntry(buildPath( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), linkName), env) } } @@ -1248,11 +1340,8 @@ object Client extends Logging { val mirror = universe.runtimeMirror(getClass.getClassLoader) try { - val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") - val hive = hiveClass.getMethod("get").invoke(null) - - val hiveConf = hiveClass.getMethod("getConf").invoke(hive) val hiveConfClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf") + val hiveConf = hiveConfClass.newInstance() val hiveConfGet = (param: String) => Option(hiveConfClass .getMethod("get", classOf[java.lang.String]) @@ -1262,6 +1351,9 @@ object Client extends Logging { // Check for local metastore if (metastore_uri != None && metastore_uri.get.toString.size > 0) { + val hiveClass = mirror.classLoader.loadClass("org.apache.hadoop.hive.ql.metadata.Hive") + val hive = hiveClass.getMethod("get").invoke(null, hiveConf.asInstanceOf[Object]) + val metastore_kerberos_principal_conf_var = mirror.classLoader .loadClass("org.apache.hadoop.hive.conf.HiveConf$ConfVars") .getField("METASTORE_KERBEROS_PRINCIPAL").get("varname").toString diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 4f42ffefa77f9..1165061db21e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -81,22 +81,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orNull // If dynamic allocation is enabled, start at the configured initial number of executors. // Default to minExecutors if no initialExecutors is set. 
- if (isDynamicAllocationEnabled) { - val minExecutorsConf = "spark.dynamicAllocation.minExecutors" - val initialExecutorsConf = "spark.dynamicAllocation.initialExecutors" - val maxExecutorsConf = "spark.dynamicAllocation.maxExecutors" - val minNumExecutors = sparkConf.getInt(minExecutorsConf, 0) - val initialNumExecutors = sparkConf.getInt(initialExecutorsConf, minNumExecutors) - val maxNumExecutors = sparkConf.getInt(maxExecutorsConf, Integer.MAX_VALUE) - - // If defined, initial executors must be between min and max - if (initialNumExecutors < minNumExecutors || initialNumExecutors > maxNumExecutors) { - throw new IllegalArgumentException( - s"$initialExecutorsConf must be between $minExecutorsConf and $maxNumExecutors!") - } - - numExecutors = initialNumExecutors - } + numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) .orNull diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 9abd09b3cc7a5..2232ffba473b5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -38,6 +38,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.launcher.YarnCommandBuilderUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils @@ -199,6 +200,7 @@ class ExecutorRunnable( // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) + YarnCommandBuilderUtils.addPermGenSizeOpt(javaOpts) val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 5f897cbcb4e9f..1deaa3743ddfa 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,8 +21,9 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -36,8 +37,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} +import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.util.Utils /** @@ -87,11 +89,12 @@ private[yarn] class YarnAllocator( @volatile private var numExecutorsFailed = 0 @volatile private var targetNumExecutors = - if (Utils.isDynamicAllocationEnabled(sparkConf)) { - sparkConf.getInt("spark.dynamicAllocation.initialExecutors", 0) - } else { - sparkConf.getInt("spark.executor.instances", 
YarnSparkHadoopUtil.DEFAULT_NUMBER_EXECUTORS) - } + YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + + // Executor loss reason requests that are pending - maps from executor ID for inquiry to a + // list of requesters that should be responded to once we find out why the given executor + // was lost. + private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] // Keep track of which container is running which executor to remove the executors later // Visible for testing. @@ -235,9 +238,7 @@ private[yarn] class YarnAllocator( val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers.asScala) - logDebug("Finished processing %d completed containers. Current running executor count: %d." .format(completedContainers.size, numExecutorsRunning)) } @@ -429,39 +430,62 @@ private[yarn] class YarnAllocator( for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId val alreadyReleased = releasedContainers.remove(containerId) - if (!alreadyReleased) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + val onHostStr = hostOpt.map(host => s" on host: $host").getOrElse("") + val exitReason = if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. numExecutorsRunning -= 1 - logInfo("Completed container %s (state: %s, exit status: %s)".format( + logInfo("Completed container %s%s (state: %s, exit status: %s)".format( containerId, + onHostStr, completedContainer.getState, completedContainer.getExitStatus)) // Hadoop 2.2.X added a ContainerExitStatus we should switch to use // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus == ContainerExitStatus.PREEMPTED) { - logInfo("Container preempted: " + containerId) - } else if (completedContainer.getExitStatus == -103) { // vmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - VMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus == -104) { // pmem limit exceeded - logWarning(memLimitExceededLogMessage( - completedContainer.getDiagnostics, - PMEM_EXCEEDED_PATTERN)) - } else if (completedContainer.getExitStatus != 0) { - logInfo("Container marked as failed: " + containerId + - ". Exit status: " + completedContainer.getExitStatus + - ". Diagnostics: " + completedContainer.getDiagnostics) - numExecutorsFailed += 1 + // now I think its ok as none of the containers are expected to exit. + val exitStatus = completedContainer.getExitStatus + val (isNormalExit, containerExitReason) = exitStatus match { + case ContainerExitStatus.SUCCESS => + (true, s"Executor for container $containerId exited normally.") + case ContainerExitStatus.PREEMPTED => + // Preemption should count as a normal exit, since YARN preempts containers merely + // to do resource sharing, and tasks that fail due to preempted executors could + // just as easily finish on any other executor. See SPARK-8167. 
+ (true, s"Container ${containerId}${onHostStr} was preempted.") + // Should probably still count memory exceeded exit codes towards task failures + case VMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + VMEM_EXCEEDED_PATTERN)) + case PMEM_EXCEEDED_EXIT_CODE => + (false, memLimitExceededLogMessage( + completedContainer.getDiagnostics, + PMEM_EXCEEDED_PATTERN)) + case unknown => + numExecutorsFailed += 1 + (false, "Container marked as failed: " + containerId + onHostStr + + ". Exit status: " + completedContainer.getExitStatus + + ". Diagnostics: " + completedContainer.getDiagnostics) + + } + if (isNormalExit) { + logInfo(containerExitReason) + } else { + logWarning(containerExitReason) } + ExecutorExited(0, isNormalExit, containerExitReason) + } else { + // If we have already released this container, then it must mean + // that the driver has explicitly requested it to be killed + ExecutorExited(completedContainer.getExitStatus, isNormalExit = true, + s"Container $containerId exited from explicit termination request.") } - if (allocatedContainerToHostMap.contains(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).get - val containerSet = allocatedHostToContainersMap.get(host).get - + for { + host <- hostOpt + containerSet <- allocatedHostToContainersMap.get(host) + } { containerSet.remove(containerId) if (containerSet.isEmpty) { allocatedHostToContainersMap.remove(host) @@ -474,18 +498,35 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - + pendingLossReasonRequests.remove(eid).foreach { pendingRequests => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) // Notify backend about the failure of the executor numUnexpectedContainerRelease += 1 - driverRef.send(RemoveExecutor(eid, - s"Yarn deallocated the executor $eid (container $containerId)")) + driverRef.send(RemoveExecutor(eid, exitReason)) } } } } + /** + * Register that some RpcCallContext has asked the AM why the executor was lost. Note that + * we can only find the loss reason to send back in the next call to allocateResources(). 
+ */ + private[yarn] def enqueueGetLossReasonRequest( + eid: String, + context: RpcCallContext): Unit = synchronized { + if (executorIdToContainer.contains(eid)) { + pendingLossReasonRequests + .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else { + logWarning(s"Tried to get the loss reason for non-existent executor $eid") + } + } + private def internalReleaseContainer(container: Container): Unit = { releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) @@ -501,6 +542,8 @@ private object YarnAllocator { Pattern.compile(s"$MEM_REGEX of $MEM_REGEX physical memory used") val VMEM_EXCEEDED_PATTERN = Pattern.compile(s"$MEM_REGEX of $MEM_REGEX virtual memory used") + val VMEM_EXCEEDED_EXIT_CODE = -103 + val PMEM_EXCEEDED_EXIT_CODE = -104 def memLimitExceededLogMessage(diagnostics: String, pattern: Pattern): String = { val matcher = pattern.matcher(diagnostics) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index df042bf291de7..d2a211f6711ff 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -49,7 +49,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * * @param conf The Yarn configuration. * @param sparkConf The Spark configuration. - * @param preferredNodeLocations Map with hints about where to allocate containers. * @param uiAddress Address of the SparkUI. * @param uiHistoryAddress Address of the application on the History Server. */ @@ -58,7 +57,6 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, - preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, uiHistoryAddress: String, securityMgr: SecurityManager diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 445d3dcd266db..f276e7efde9d7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -314,5 +314,28 @@ object YarnSparkHadoopUtil { def getClassPathSeparator(): String = { classPathSeparatorField.get(null).asInstanceOf[String] } + + /** + * Returns the initial target number of executors, which depends on whether dynamic + * allocation is enabled. + */ + def getInitialTargetExecutorNumber(conf: SparkConf): Int = { + if (Utils.isDynamicAllocationEnabled(conf)) { + val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + val initialNumExecutors = + conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) + val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", Int.MaxValue) + require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, + s"initial executor number $initialNumExecutors must be between min executor number " + + s"$minNumExecutors and max executor number $maxNumExecutors") + + initialNumExecutors + } else { + val targetNumExecutors = + sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(DEFAULT_NUMBER_EXECUTORS) + // System property can override environment variable.
+ conf.getInt("spark.executor.instances", targetNumExecutors) + } + } } diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala index 3ac36ef0a1c3f..7d246bf407121 100644 --- a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -17,11 +17,28 @@ package org.apache.spark.launcher +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer + /** - * Exposes needed methods + * Exposes methods from the launcher library that are used by the YARN backend. */ private[spark] object YarnCommandBuilderUtils { - def quoteForBatchScript(arg: String) : String = { + + def quoteForBatchScript(arg: String): String = { CommandBuilderUtils.quoteForBatchScript(arg) } + + /** + * Adds the perm gen configuration to the list of java options if needed and not yet added. + * + * Note that this method adds the option based on the local JVM version; if the node where + * the container is running has a different Java version, there's a risk that the option will + * not be added (e.g. if the AM is running Java 8 but the container's node is set up to use + * Java 7). + */ + def addPermGenSizeOpt(args: ListBuffer[String]): Unit = { + CommandBuilderUtils.addPermGenSizeOpt(args.asJava) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d06d95140438c..20771f655473c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.launcher.SparkAppHandle import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class YarnClientSchedulerBackend( @@ -177,9 +178,18 @@ private[spark] class YarnClientSchedulerBackend( if (monitorThread != null) { monitorThread.stopMonitor() } + + // Report a final state to the launcher if one is connected. This is needed since in client + // mode this backend doesn't let the app monitor loop run to completion, so it does not report + // the final state itself. + // + // Note: there's not enough information at this point to provide a better final state, + // so assume the application was successful. 
+ client.reportLauncherState(SparkAppHandle.State.FINISHED) + super.stop() - client.stop() YarnSparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() + client.stop() logInfo("Stopped") } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 1aed5a1675075..50b699f11b21c 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,21 +17,13 @@ package org.apache.spark.scheduler.cluster -import java.net.NetworkInterface - import org.apache.hadoop.yarn.api.ApplicationConstants.Environment - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.yarn.api.records.NodeState -import org.apache.hadoop.yarn.client.api.YarnClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.SparkContext import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.{IntParam, Utils} +import org.apache.spark.util.Utils private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -40,13 +32,7 @@ private[spark] class YarnClusterSchedulerBackend( override def start() { super.start() - totalExpectedExecutors = DEFAULT_NUMBER_EXECUTORS - if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) - .getOrElse(totalExpectedExecutors) - } - // System property can override environment variable. - totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sc.conf) } override def applicationId(): String = diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b8a5dbf6373e..6b9a799954bf1 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -23,6 +23,9 @@ log4j.appender.file.file=target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +# Ignore messages below warning level from a few verbose libraries. 
+log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark-project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 17c59ff06e0c1..12494b01054ba 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -22,15 +22,18 @@ import java.util.Properties import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.util.Utils abstract class BaseYarnClusterSuite @@ -46,13 +49,14 @@ abstract class BaseYarnClusterSuite |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.mortbay=WARN |log4j.logger.org.spark-project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ protected var tempDir: File = _ private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ + protected var hadoopConfDir: File = _ private var logConfDir: File = _ def newYarnConfig(): YarnConfiguration @@ -120,15 +124,77 @@ abstract class BaseYarnClusterSuite clientMode: Boolean, klass: String, appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, + sparkArgs: Seq[(String, String)] = Nil, extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, extraConf: Map[String, String] = Map(), - extraEnv: Map[String, String] = Map()): Unit = { + extraEnv: Map[String, String] = Map()): SparkAppHandle.State = { val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() + val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf) + val env = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv + + val launcher = new SparkLauncher(env.asJava) + if (klass.endsWith(".py")) { + launcher.setAppResource(klass) + } else { + launcher.setMainClass(klass) + launcher.setAppResource(fakeSparkJar.getAbsolutePath()) + } + launcher.setSparkHome(sys.props("spark.test.home")) + .setMaster(master) + .setConf("spark.executor.instances", "1") + .setPropertiesFile(propsFile) + .addAppArgs(appArgs.toArray: _*) + + sparkArgs.foreach { case (name, value) => + if (value != null) { + launcher.addSparkArg(name, value) + } else { + launcher.addSparkArg(name) + } + } + extraJars.foreach(launcher.addJar) - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + val handle = launcher.startApplication() + try { + eventually(timeout(2 minutes), interval(1 second)) { + assert(handle.getState().isFinal()) + } + } finally { + handle.kill() + } + + handle.getState() + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. 
So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(finalState: SparkAppHandle.State, result: File): Unit = { + checkResult(finalState, result, "success") + } + + protected def checkResult( + finalState: SparkAppHandle.State, + result: File, + expected: String): Unit = { + finalState should be (SparkAppHandle.State.FINISHED) + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + + protected def createConfFile( + extraClassPath: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): String = { + val props = new Properties() + props.put("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) val testClasspath = new TestClasspathBuilder() .buildClassPath( @@ -138,69 +204,28 @@ abstract class BaseYarnClusterSuite .asScala .mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", testClasspath) - props.setProperty("spark.executor.extraClassPath", testClasspath) + props.put("spark.driver.extraClassPath", testClasspath) + props.put("spark.executor.extraClassPath", testClasspath) // SPARK-4267: make sure java options are propagated correctly. props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - yarnCluster.getConfig.asScala.foreach { e => + yarnCluster.getConfig().asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } - sys.props.foreach { case (k, v) => if (k.startsWith("spark.")) { props.setProperty(k, v) } } - extraConf.foreach { case (k, v) => props.setProperty(k, v) } val propsFile = File.createTempFile("spark", ".properties", tempDir) val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) props.store(writer, "Spark properties.") writer.close() - - val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. 
- */ - protected def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - protected def checkResult(result: File, expected: String): Unit = { - val resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - protected def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") + propsFile.getAbsolutePath() } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b5a42fd6afd98..6db012a77a936 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -19,19 +19,24 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URL +import java.util.{HashMap => JHashMap, Properties} import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.launcher._ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.tags.ExtendedYarnTest import org.apache.spark.util.Utils /** @@ -39,6 +44,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ +@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() @@ -80,10 +86,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. 
- val exception = intercept[SparkException] { - runSpark(false, mainClassName(YarnClusterDriver.getClass)) - fail("Spark application should have failed.") - } + val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) + finalState should be (SparkAppHandle.State.FAILED) } test("run Python application in yarn-client mode") { @@ -102,11 +106,42 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testUseClassPathFirst(false) } + test("monitor app using launcher library") { + val env = new JHashMap[String, String]() + env.put("YARN_CONF_DIR", hadoopConfDir.getAbsolutePath()) + + val propsFile = createConfFile() + val handle = new SparkLauncher(env) + .setSparkHome(sys.props("spark.test.home")) + .setConf("spark.ui.enabled", "false") + .setPropertiesFile(propsFile) + .setMaster("yarn-client") + .setAppResource("spark-internal") + .setMainClass(mainClassName(YarnLauncherTestApp.getClass)) + .startApplication() + + try { + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.RUNNING) + } + + handle.getAppId() should not be (null) + handle.getAppId() should startWith ("application_") + handle.stop() + + eventually(timeout(30 seconds), interval(100 millis)) { + handle.getState() should be (SparkAppHandle.State.KILLED) + } + } finally { + handle.kill() + } + } + private def testBasicYarnApp(clientMode: Boolean): Unit = { val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) - checkResult(result) + checkResult(finalState, result) } private def testPySpark(clientMode: Boolean): Unit = { @@ -118,7 +153,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. 
val sparkHome = sys.props("spark.test.home"); val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.8.2.1-src.zip", + s"$sparkHome/python/lib/py4j-0.9-src.zip", s"$sparkHome/python") val extraEnv = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), @@ -141,11 +176,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") val result = File.createTempFile("result", null, tempDir) - runSpark(clientMode, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFiles), + val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files" -> pyFiles), appArgs = Seq(result.getAbsolutePath()), extraEnv = extraEnv) - checkResult(result) + checkResult(finalState, result) } private def testUseClassPathFirst(clientMode: Boolean): Unit = { @@ -154,15 +189,15 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir) val driverResult = File.createTempFile("driver", null, tempDir) val executorResult = File.createTempFile("executor", null, tempDir) - runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), extraClassPath = Seq(originalJar.getPath()), extraJars = Seq("local:" + userJar.getPath()), extraConf = Map( "spark.driver.userClassPathFirst" -> "true", "spark.executor.userClassPathFirst" -> "true")) - checkResult(driverResult, "OVERRIDDEN") - checkResult(executorResult, "OVERRIDDEN") + checkResult(finalState, driverResult, "OVERRIDDEN") + checkResult(finalState, executorResult, "OVERRIDDEN") } } @@ -209,8 +244,8 @@ private object YarnClusterDriver extends Logging with Matchers { data should be (Set(1, 2, 3, 4)) result = "success" } finally { - sc.stop() Files.write(result, status, UTF_8) + sc.stop() } // verify log urls are present @@ -295,3 +330,18 @@ private object YarnClasspathTest extends Logging { } } + +private object YarnLauncherTestApp { + + def main(args: Array[String]): Unit = { + // Do not stop the application; the test will stop it using the launcher lib. Just run a task + // that will prevent the process from exiting. 
+ val sc = new SparkContext(new SparkConf()) + sc.parallelize(Seq(1)).foreach { i => + this.synchronized { + wait() + } + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 8d9c9b3004eda..c17e8695c24fb 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -28,10 +28,12 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.network.shuffle.ShuffleTestAccessor import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} +import org.apache.spark.tags.ExtendedYarnTest /** * Integration test for the external shuffle service with a yarn mini-cluster */ +@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { @@ -51,7 +53,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { logInfo("Shuffle service port = " + shuffleServicePort) val result = File.createTempFile("result", null, tempDir) - runSpark( + val finalState = runSpark( false, mainClassName(YarnExternalShuffleDriver.getClass), appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), @@ -60,7 +62,7 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { "spark.shuffle.service.port" -> shuffleServicePort.toString ) ) - checkResult(result) + checkResult(finalState, result) assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 49bee0866dd43..e1c67db76571f 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -233,4 +234,15 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") } + + test("check different hadoop utils based on env variable") { + try { + System.setProperty("SPARK_YARN_MODE", "true") + assert(SparkHadoopUtil.get.getClass === classOf[YarnSparkHadoopUtil]) + System.setProperty("SPARK_YARN_MODE", "false") + assert(SparkHadoopUtil.get.getClass === classOf[SparkHadoopUtil]) + } finally { + System.clearProperty("SPARK_YARN_MODE") + } + } }
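The Client change above guards the call to setAttemptFailuresValidityInterval with reflection so the same code keeps working on YARN versions that predate that setter. The following is a minimal, self-contained sketch of that pattern only; the Target class and setValidityInterval method are hypothetical stand-ins, not YARN APIs.

object OptionalSetterSketch {

  // Stand-in for a submission context whose newer versions expose an extra setter.
  class Target {
    def setValidityInterval(ms: Long): Unit = println(s"validity interval set to $ms ms")
  }

  def main(args: Array[String]): Unit = {
    val target = new Target
    try {
      // Look the setter up by name so the code still runs against versions of the class
      // that do not define it.
      val method = target.getClass.getMethod("setValidityInterval", classOf[Long])
      method.invoke(target, java.lang.Long.valueOf(60000L))
    } catch {
      case _: NoSuchMethodException =>
        println("setValidityInterval not available in this version; ignoring the setting")
    }
  }
}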
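Several hunks above replace new URI(path) with Utils.resolveURI(path) when handling user-supplied paths. The sketch below only approximates the intent, under the assumption that two cases matter here: URIs that already carry a scheme (including "local:") are kept as-is, while bare paths are resolved to absolute file: URIs; Spark's real helper also handles fragments and malformed input.

import java.io.File
import java.net.URI

object ResolveUriSketch {

  // Keep URIs that already have a scheme; turn bare local paths into absolute file: URIs.
  def resolve(path: String): URI = {
    val uri = new URI(path)
    if (uri.getScheme != null) uri else new File(path).getAbsoluteFile.toURI
  }

  def main(args: Array[String]): Unit = {
    println(resolve("hdfs:///jars/app.jar"))  // scheme preserved
    println(resolve("local:/opt/spark/lib"))  // "local:" scheme preserved, so it is not uploaded
    println(resolve("jars/app.jar"))          // bare path becomes a file: URI
  }
}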
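The YarnAllocator change above buffers loss-reason inquiries in pendingLossReasonRequests via enqueueGetLossReasonRequest and answers every waiting caller once a later allocation round reveals why the container exited. Below is a simplified sketch of that buffering pattern, using plain callbacks as stand-ins for Spark's RpcCallContext.

import scala.collection.mutable

object PendingReplySketch {

  // Callers register a callback under the executor ID they are asking about...
  private val pending = new mutable.HashMap[String, mutable.Buffer[String => Unit]]

  def ask(executorId: String)(reply: String => Unit): Unit = synchronized {
    pending.getOrElseUpdate(executorId, mutable.ArrayBuffer.empty[String => Unit]) += reply
  }

  // ...and every registered caller is answered once the reason becomes known.
  def resolve(executorId: String, reason: String): Unit = synchronized {
    pending.remove(executorId).foreach(_.foreach(_.apply(reason)))
  }

  def main(args: Array[String]): Unit = {
    ask("exec-1")(r => println(s"first caller got: $r"))
    ask("exec-1")(r => println(s"second caller got: $r"))
    resolve("exec-1", "Container preempted by YARN")
  }
}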
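The launcher-related changes above make the YARN client report RUNNING/FINISHED/FAILED/KILLED transitions back to a connected SparkAppHandle, which is what the new "monitor app using launcher library" test exercises. A rough usage sketch of that flow follows; the Spark home, application jar, and main class are hypothetical placeholders.

import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

object LauncherMonitorSketch {
  def main(args: Array[String]): Unit = {
    val handle: SparkAppHandle = new SparkLauncher()
      .setSparkHome("/path/to/spark")       // hypothetical Spark installation
      .setAppResource("/path/to/app.jar")   // hypothetical application jar
      .setMainClass("com.example.Main")     // hypothetical main class
      .setMaster("yarn-client")
      .startApplication()

    // Poll the handle until YARN reports a final state; the client forwards state
    // transitions to connected handles as the application progresses.
    while (!handle.getState().isFinal()) {
      println(s"appId=${handle.getAppId()} state=${handle.getState()}")
      Thread.sleep(1000)
    }
    println(s"final state: ${handle.getState()}")
  }
}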