Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.analysis

import java.util.Locale
import java.util.concurrent.ConcurrentHashMap
import javax.annotation.concurrent.GuardedBy

import scala.jdk.CollectionConverters._
import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.SparkUnsupportedOperationException
Expand Down Expand Up @@ -195,8 +195,9 @@ object FunctionRegistryBase {

trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging {

@GuardedBy("this")
protected val functionBuilders =
new ConcurrentHashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]
new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

// Resolution of the function name is always case insensitive, but the database name
// depends on the caller
Expand All @@ -219,36 +220,45 @@ trait SimpleFunctionRegistryBase[T] extends FunctionRegistryBase[T] with Logging
def internalRegisterFunction(
name: FunctionIdentifier,
info: ExpressionInfo,
builder: FunctionBuilder): Unit = {
builder: FunctionBuilder): Unit = synchronized {
val newFunction = (info, builder)
functionBuilders.put(name, newFunction) match {
case previousFunction if previousFunction != null =>
case Some(previousFunction) if previousFunction != newFunction =>
logWarning(log"The function ${MDC(FUNCTION_NAME, name)} replaced a " +
log"previously registered function.")
case _ =>
}
}

override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): T = {
val func = Option(functionBuilders.get(normalizeFuncName(name))).map(_._2).getOrElse {
throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
val func = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse {
throw QueryCompilationErrors.unresolvedRoutineError(name, Seq("system.builtin"))
}
}
func(children)
}

override def listFunction(): Seq[FunctionIdentifier] =
functionBuilders.keys().asScala.toSeq
override def listFunction(): Seq[FunctionIdentifier] = synchronized {
functionBuilders.iterator.map(_._1).toList
}

override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] =
Option(functionBuilders.get(normalizeFuncName(name))).map(_._1)
override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._1)
}

override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] =
Option(functionBuilders.get(normalizeFuncName(name))).map(_._2)
override def lookupFunctionBuilder(
name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._2)
}

override def dropFunction(name: FunctionIdentifier): Boolean =
Option(functionBuilders.remove(normalizeFuncName(name))).isDefined
override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
functionBuilders.remove(normalizeFuncName(name)).isDefined
}

override def clear(): Unit = functionBuilders.clear()
override def clear(): Unit = synchronized {
functionBuilders.clear()
}
}

/**
Expand Down Expand Up @@ -298,11 +308,7 @@ class SimpleFunctionRegistry

override def clone(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
val iterator = functionBuilders.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val name = entry.getKey
val (info, builder) = entry.getValue
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.internalRegisterFunction(name, info, builder)
}
registry
Expand Down Expand Up @@ -1030,11 +1036,7 @@ class SimpleTableFunctionRegistry extends SimpleFunctionRegistryBase[LogicalPlan

override def clone(): SimpleTableFunctionRegistry = synchronized {
val registry = new SimpleTableFunctionRegistry
val iterator = functionBuilders.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val name = entry.getKey
val (info, builder) = entry.getValue
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.internalRegisterFunction(name, info, builder)
}
registry
Expand Down