From a6e2b9740f91347a9134185e0a15faf7462f5175 Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Mon, 14 Jul 2025 17:46:55 -0400 Subject: [PATCH 1/6] Add RPC header for access token --- .../java/org/apache/hadoop/ipc/Server.java | 480 +++++++++--------- .../hadoop/security/AuthorizationContext.java | 22 + .../org/apache/hadoop/util/ProtoUtil.java | 7 + .../src/main/proto/RpcHeader.proto | 2 + .../security/TestAuthorizationContext.java | 52 ++ .../TestAuthorizationHeaderPropagation.java | 64 +++ 6 files changed, 394 insertions(+), 233 deletions(-) create mode 100644 hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java create mode 100644 hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java create mode 100644 hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index 289403d942bd1..f76212e3c76b8 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -139,11 +139,12 @@ import org.apache.hadoop.thirdparty.protobuf.Message; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.hadoop.security.AuthorizationContext; /** An abstract IPC service. IPC calls take a single {@link Writable} as a * parameter, and return a {@link Writable} as their value. A service runs on * a port and is defined by a parameter class and a value class. - * + * * @see Client */ @Public @@ -208,7 +209,7 @@ static class ExceptionsHandler { /** * Add exception classes for which server won't log stack traces. * Optimized for infrequent invocation. - * @param exceptionClass exception classes + * @param exceptionClass exception classes */ void addTerseLoggingExceptions(Class... exceptionClass) { terseExceptions.addAll(Arrays @@ -239,14 +240,14 @@ boolean isSuppressedLog(Class t) { } - + /** * If the user accidentally sends an HTTP GET to an IPC port, we detect this * and send back a nicer response. */ private static final ByteBuffer HTTP_GET_BYTES = ByteBuffer.wrap( "GET ".getBytes(StandardCharsets.UTF_8)); - + /** * An HTTP response to send back if we detect an HTTP request to our IPC * port. @@ -261,7 +262,7 @@ boolean isSuppressedLog(Class t) { * Initial and max size of response buffer */ static int INITIAL_RESP_BUF_SIZE = 10240; - + static class RpcKindMapValue { final Class rpcRequestWrapperClass; final RpcInvoker rpcInvoker; @@ -270,42 +271,42 @@ static class RpcKindMapValue { RpcInvoker rpcInvoker) { this.rpcInvoker = rpcInvoker; this.rpcRequestWrapperClass = rpcRequestWrapperClass; - } + } } static Map rpcKindMap = new HashMap<>(4); - - + + /** * Register a RPC kind and the class to deserialize the rpc request. - * + * * Called by static initializers of rpcKind Engines * @param rpcKind - input rpcKind. * @param rpcRequestWrapperClass - this class is used to deserialze the * the rpc request. * @param rpcInvoker - use to process the calls on SS. */ - - public static void registerProtocolEngine(RPC.RpcKind rpcKind, + + public static void registerProtocolEngine(RPC.RpcKind rpcKind, Class rpcRequestWrapperClass, RpcInvoker rpcInvoker) { - RpcKindMapValue old = + RpcKindMapValue old = rpcKindMap.put(rpcKind, new RpcKindMapValue(rpcRequestWrapperClass, rpcInvoker)); if (old != null) { rpcKindMap.put(rpcKind, old); throw new IllegalArgumentException("ReRegistration of rpcKind: " + - rpcKind); + rpcKind); } LOG.debug("rpcKind={}, rpcRequestWrapperClass={}, rpcInvoker={}.", rpcKind, rpcRequestWrapperClass, rpcInvoker); } - + public Class getRpcRequestWrapper( RpcKindProto rpcKind) { if (rpcRequestClass != null) return rpcRequestClass; RpcKindMapValue val = rpcKindMap.get(ProtoUtil.convert(rpcKind)); - return (val == null) ? null : val.rpcRequestWrapperClass; + return (val == null) ? null : val.rpcRequestWrapperClass; } protected RpcInvoker getServerRpcInvoker(RPC.RpcKind rpcKind) { @@ -314,22 +315,22 @@ protected RpcInvoker getServerRpcInvoker(RPC.RpcKind rpcKind) { public static RpcInvoker getRpcInvoker(RPC.RpcKind rpcKind) { RpcKindMapValue val = rpcKindMap.get(rpcKind); - return (val == null) ? null : val.rpcInvoker; + return (val == null) ? null : val.rpcInvoker; } - + public static final Logger LOG = LoggerFactory.getLogger(Server.class); public static final Logger AUDITLOG = LoggerFactory.getLogger("SecurityLogger."+Server.class.getName()); private static final String AUTH_FAILED_FOR = "Auth failed for "; private static final String AUTH_SUCCESSFUL_FOR = "Auth successful for "; - + private static final ThreadLocal SERVER = new ThreadLocal(); - private static final Map> PROTOCOL_CACHE = + private static final Map> PROTOCOL_CACHE = new ConcurrentHashMap>(); - - static Class getProtocolClass(String protocolName, Configuration conf) + + static Class getProtocolClass(String protocolName, Configuration conf) throws ClassNotFoundException { Class protocol = PROTOCOL_CACHE.get(protocolName); if (protocol == null) { @@ -338,7 +339,7 @@ static Class getProtocolClass(String protocolName, Configuration conf) } return protocol; } - + /** @return Returns the server instance called under or null. May be called under * {@link #call(Writable, long)} implementations, and under {@link Writable} * methods of paramters and return values. Permits applications to access @@ -346,7 +347,7 @@ static Class getProtocolClass(String protocolName, Configuration conf) public static Server get() { return SERVER.get(); } - + /** This is set to Call object before Handler invokes an RPC and reset * after the call returns. */ @@ -362,14 +363,14 @@ public static ThreadLocal getCurCall() { * Returns the currently active RPC call's sequential ID number. A negative * call ID indicates an invalid value, such as if there is no currently active * RPC call. - * + * * @return int sequential ID number of currently active RPC call */ public static int getCallId() { Call call = CurCall.get(); return call != null ? call.callId : RpcConstants.INVALID_CALL_ID; } - + /** * @return The current active RPC call's retry count. -1 indicates the retry * cache is not supported in the client side. @@ -708,7 +709,7 @@ void updateDeferredMetrics(Call call, String name) { } /** - * A convenience method to bind to a given address and report + * A convenience method to bind to a given address and report * better exceptions if the address is not a valid host. * @param socket the socket to bind * @param address the address to bind to @@ -717,12 +718,12 @@ void updateDeferredMetrics(Call call, String name) { * @throws UnknownHostException if the address isn't a valid host name * @throws IOException other random errors from bind */ - public static void bind(ServerSocket socket, InetSocketAddress address, + public static void bind(ServerSocket socket, InetSocketAddress address, int backlog) throws IOException { bind(socket, address, backlog, null, null); } - public static void bind(ServerSocket socket, InetSocketAddress address, + public static void bind(ServerSocket socket, InetSocketAddress address, int backlog, Configuration conf, String rangeConf) throws IOException { try { IntegerRanges range = null; @@ -782,7 +783,7 @@ public RpcMetrics getRpcMetrics() { public RpcDetailedMetrics getRpcDetailedMetrics() { return rpcDetailedMetrics; } - + @VisibleForTesting Iterable getHandlers() { return Arrays.asList(handlers); @@ -1457,7 +1458,7 @@ public String toString() { /** Listens on the socket. Creates jobs for the handler threads*/ private class Listener extends Thread { - + private ServerSocketChannel acceptChannel = null; //the accept channel private Selector selector = null; //the selector that we use for the server private Reader[] readers = null; @@ -1504,7 +1505,7 @@ private class Listener extends Thread { void setIsAuxiliary() { this.isOnAuxiliaryPort = true; } - + private class Reader extends Thread { final private BlockingQueue pendingConnections; private final Selector readSelector; @@ -1516,7 +1517,7 @@ private class Reader extends Thread { new LinkedBlockingQueue(readerPendingConnectionQueue); this.readSelector = Selector.open(); } - + @Override public void run() { LOG.info("Starting " + Thread.currentThread().getName()); @@ -1620,7 +1621,7 @@ public void run() { } } catch (OutOfMemoryError e) { // we can run out of memory if we have too many threads - // log the event and sleep for a minute and give + // log the event and sleep for a minute and give // some thread(s) a chance to finish LOG.warn("Out of Memory in server select", e); closeCurrentConnection(key, e); @@ -1640,7 +1641,7 @@ public void run() { selector= null; acceptChannel= null; - + // close all connections connectionManager.stopIdleScan(); connectionManager.closeAll(); @@ -1660,7 +1661,7 @@ private void closeCurrentConnection(SelectionKey key, Throwable e) { InetSocketAddress getAddress() { return (InetSocketAddress)acceptChannel.socket().getLocalSocketAddress(); } - + void doAccept(SelectionKey key) throws InterruptedException, IOException, OutOfMemoryError { ServerSocketChannel server = (ServerSocketChannel) key.channel(); SocketChannel channel; @@ -1669,7 +1670,7 @@ void doAccept(SelectionKey key) throws InterruptedException, IOException, OutOf channel.configureBlocking(false); channel.socket().setTcpNoDelay(tcpNoDelay); channel.socket().setKeepAlive(true); - + Reader reader = getReader(); Connection c = connectionManager.register(channel, this.listenPort, this.isOnAuxiliaryPort); @@ -1690,10 +1691,10 @@ void doRead(SelectionKey key) throws InterruptedException { int count; Connection c = (Connection)key.attachment(); if (c == null) { - return; + return; } c.setLastContact(Time.now()); - + try { count = c.readAndProcess(); } catch (InterruptedException ieo) { @@ -1716,7 +1717,7 @@ void doRead(SelectionKey key) throws InterruptedException { else { c.setLastContact(Time.now()); } - } + } synchronized void doStop() { if (selector != null) { @@ -1734,7 +1735,7 @@ synchronized void doStop() { r.shutdown(); } } - + synchronized Selector getSelector() { return selector; } // The method that will return the next reader to work with // Simplistic implementation of round robin for now @@ -1771,7 +1772,7 @@ public void run() { } } } - + private void doRunLoop() { long lastPurgeTimeNanos = 0; // last check for old calls. @@ -1812,7 +1813,7 @@ private void doRunLoop() { // LOG.debug("Checking for old call responses."); ArrayList calls; - + // get the list of channels from list of keys. synchronized (writeSelector.keys()) { calls = new ArrayList(writeSelector.keys().size()); @@ -1820,7 +1821,7 @@ private void doRunLoop() { while (iter.hasNext()) { SelectionKey key = iter.next(); RpcCall call = (RpcCall)key.attachment(); - if (call != null && key.channel() == call.connection.channel) { + if (call != null && key.channel() == call.connection.channel) { calls.add(call); } } @@ -1869,7 +1870,7 @@ private void doAsyncWrite(SelectionKey key) throws IOException { } // - // Remove calls that have been pending in the responseQueue + // Remove calls that have been pending in the responseQueue // for a long time. // private void doPurge(RpcCall call, long now) { @@ -1932,18 +1933,18 @@ private boolean processResponse(LinkedList responseQueue, Thread.currentThread().getName(), call, numBytes); } else { // - // If we were unable to write the entire response out, then - // insert in Selector queue. + // If we were unable to write the entire response out, then + // insert in Selector queue. // call.connection.responseQueue.addFirst(call); - + if (inHandler) { // set the serve time when the response has to be sent later call.responseTimestampNanos = Time.monotonicNowNanos(); - + incPending(); try { - // Wakeup the thread blocked on select, only then can the call + // Wakeup the thread blocked on select, only then can the call // to channel.register() complete. writeSelector.wakeup(); channel.register(writeSelector, SelectionKey.OP_WRITE, call); @@ -2006,12 +2007,12 @@ private synchronized void waitPending() throws InterruptedException { public enum AuthProtocol { NONE(0), SASL(-33); - + public final int callId; AuthProtocol(int callId) { this.callId = callId; } - + static AuthProtocol valueOf(int callId) { for (AuthProtocol authType : AuthProtocol.values()) { if (authType.callId == callId) { @@ -2021,12 +2022,12 @@ static AuthProtocol valueOf(int callId) { return null; } }; - + /** * Wrapper for RPC IOExceptions to be returned to the client. Used to * let exceptions bubble up to top of processOneRpc where the correct * callId can be associated with the response. Also used to prevent - * unnecessary stack trace logging if it's not an internal server error. + * unnecessary stack trace logging if it's not an internal server error. */ @SuppressWarnings("serial") private static class FatalRpcServerException extends RpcServerException { @@ -2068,7 +2069,7 @@ public class Connection { private int dataLength; private Socket socket; - // Cache the remote host & port info so that even if the socket is + // Cache the remote host & port info so that even if the socket is // disconnected, we can say where it used to connect to. /** @@ -2108,13 +2109,13 @@ public class Connection { private boolean sentNegotiate = false; private boolean useWrap = false; - + public Connection(SocketChannel channel, long lastContact, int ingressPort, boolean isOnAuxiliaryPort) { this.channel = channel; this.lastContact = lastContact; this.data = null; - + // the buffer is initialized to read the "hrpc" and after that to read // the length of the Rpc-packet (i.e 4 bytes) this.dataLengthBuffer = ByteBuffer.allocate(4); @@ -2140,7 +2141,7 @@ public Connection(SocketChannel channel, long lastContact, socketSendBufferSize); } } - } + } @Override public String toString() { @@ -2199,17 +2200,17 @@ public Configuration getConf() { private boolean isIdle() { return rpcCount.get() == 0; } - + /* Decrement the outstanding RPC count */ private void decRpcCount() { rpcCount.decrementAndGet(); } - + /* Increment the outstanding RPC count */ private void incRpcCount() { rpcCount.incrementAndGet(); } - + private UserGroupInformation getAuthorizedUgi(String authorizedId) throws InvalidToken, AccessControlException { if (authMethod == AuthMethod.TOKEN) { @@ -2252,7 +2253,7 @@ private void saslReadAndProcess(RpcWritable.Buffer buffer) throws * that are wrapped as a cause of parameter e are unwrapped so that they can * be sent as the true cause to the client side. In case of * {@link InvalidToken} we go one level deeper to get the true cause. - * + * * @param e the exception that may have a cause we want to unwrap. * @return the true cause for some exceptions. */ @@ -2268,7 +2269,7 @@ private Throwable getTrueCause(IOException e) { // callbacks to only returning InvalidToken, but some services // need to throw other exceptions (ex. NN + StandyException), // so for now we'll tunnel the real exceptions via an - // InvalidToken's cause which normally is not set + // InvalidToken's cause which normally is not set if (cause.getCause() != null) { cause = cause.getCause(); } @@ -2278,15 +2279,15 @@ private Throwable getTrueCause(IOException e) { } return e; } - + /** * Process saslMessage and send saslResponse back * @param saslMessage received SASL message * @throws RpcServerException setup failed due to SASL negotiation - * failure, premature or invalid connection context, or other state - * errors. This exception needs to be sent to the client. This - * exception will wrap {@link RetriableException}, - * {@link InvalidToken}, {@link StandbyException} or + * failure, premature or invalid connection context, or other state + * errors. This exception needs to be sent to the client. This + * exception will wrap {@link RetriableException}, + * {@link InvalidToken}, {@link StandbyException} or * {@link SaslException}. * @throws IOException if sending reply fails * @throws InterruptedException @@ -2330,7 +2331,7 @@ private void saslProcess(RpcSaslProto saslMessage) throw tce; } } - + if (saslServer != null && saslServer.isComplete()) { if (LOG.isDebugEnabled()) { LOG.debug("SASL server context established. Negotiated QoP is {}.", @@ -2364,15 +2365,15 @@ private void saslProcess(RpcSaslProto saslMessage) } } } - + /** * Process a saslMessge. * @param saslMessage received SASL message * @return the sasl response to send back to client - * @throws SaslException if authentication or generating response fails, + * @throws SaslException if authentication or generating response fails, * or SASL protocol mixup * @throws IOException if a SaslServer cannot be created - * @throws AccessControlException if the requested authentication type + * @throws AccessControlException if the requested authentication type * is not supported or trying to re-attempt negotiation. * @throws InterruptedException */ @@ -2380,7 +2381,7 @@ private RpcSaslProto processSaslMessage(RpcSaslProto saslMessage) throws SaslException, IOException, AccessControlException, InterruptedException { final RpcSaslProto saslResponse; - final SaslState state = saslMessage.getState(); // required + final SaslState state = saslMessage.getState(); // required switch (state) { case NEGOTIATE: { if (sentNegotiate) { @@ -2503,7 +2504,7 @@ private void checkDataLength(int dataLength) throws IOException { throw new IOException(error); } else if (dataLength > maxDataLength) { String error = "Requested data length " + dataLength + - " is longer than maximum configured RPC length " + + " is longer than maximum configured RPC length " + maxDataLength + ". RPC came from " + getHostAddress(); LOG.warn(error); throw new IOException(error); @@ -2511,17 +2512,17 @@ private void checkDataLength(int dataLength) throws IOException { } /** - * This method reads in a non-blocking fashion from the channel: - * this method is called repeatedly when data is present in the channel; + * This method reads in a non-blocking fashion from the channel: + * this method is called repeatedly when data is present in the channel; * when it has enough data to process one rpc it processes that rpc. - * - * On the first pass, it processes the connectionHeader, - * connectionContext (an outOfBand RPC) and at most one RPC request that + * + * On the first pass, it processes the connectionHeader, + * connectionContext (an outOfBand RPC) and at most one RPC request that * follows that. On future passes it will process at most one RPC request. - * - * Quirky things: dataLengthBuffer (4 bytes) is used to read "hrpc" OR + * + * Quirky things: dataLengthBuffer (4 bytes) is used to read "hrpc" OR * rpc request length. - * + * * @return -1 in case of error, else num bytes read so far * @throws IOException - internal error that should not be returned to * client, typically failure to respond to client @@ -2532,11 +2533,11 @@ public int readAndProcess() throws IOException, InterruptedException { // dataLengthBuffer is used to read "hrpc" or the rpc-packet length int count = -1; if (dataLengthBuffer.remaining() > 0) { - count = channelRead(channel, dataLengthBuffer); - if (count < 0 || dataLengthBuffer.remaining() > 0) + count = channelRead(channel, dataLengthBuffer); + if (count < 0 || dataLengthBuffer.remaining() > 0) return count; } - + if (!connectionHeaderRead) { // Every connection is expected to send the header; // so far we read "hrpc" of the connection header. @@ -2552,7 +2553,7 @@ public int readAndProcess() throws IOException, InterruptedException { // TODO we should add handler for service class later this.setServiceClass(connectionHeaderBuf.get(1)); dataLengthBuffer.flip(); - + // Check if it looks like the user is hitting an IPC port // with an HTTP GET - this is a common error, so we can // send back a simple string indicating as much. @@ -2578,16 +2579,16 @@ public int readAndProcess() throws IOException, InterruptedException { setupBadVersionResponse(version); return -1; } - + // this may switch us into SIMPLE - authProtocol = initializeAuthContext(connectionHeaderBuf.get(2)); - + authProtocol = initializeAuthContext(connectionHeaderBuf.get(2)); + dataLengthBuffer.clear(); // clear to next read rpc packet len connectionHeaderBuf = null; connectionHeaderRead = true; continue; // connection header read, now read 4 bytes rpc packet len } - + if (data == null) { // just read 4 bytes - length of RPC packet dataLengthBuffer.flip(); dataLength = dataLengthBuffer.getInt(); @@ -2597,7 +2598,7 @@ public int readAndProcess() throws IOException, InterruptedException { } // Now read the RPC packet count = channelRead(channel, data); - + if (data.remaining() == 0) { dataLengthBuffer.clear(); // to read length of future rpc packets data.flip(); @@ -2610,7 +2611,7 @@ public int readAndProcess() throws IOException, InterruptedException { if (!isHeaderRead) { continue; } - } + } return count; } return -1; @@ -2622,7 +2623,7 @@ private AuthProtocol initializeAuthContext(int authType) if (authProtocol == null) { IOException ioe = new IpcException("Unknown auth protocol:" + authType); doSaslReply(ioe); - throw ioe; + throw ioe; } boolean isSimpleEnabled = enabledAuthMethods.contains(AuthMethod.SIMPLE); switch (authProtocol) { @@ -2645,10 +2646,10 @@ private AuthProtocol initializeAuthContext(int authType) } /** - * Process the Sasl's Negotiate request, including the optimization of + * Process the Sasl's Negotiate request, including the optimization of * accelerating token negotiation. - * @return the response to Negotiate request - the list of enabled - * authMethods and challenge if the TOKENS are supported. + * @return the response to Negotiate request - the list of enabled + * authMethods and challenge if the TOKENS are supported. * @throws SaslException - if attempt to generate challenge fails. * @throws IOException - if it fails to create the SASL server for Tokens */ @@ -2670,27 +2671,27 @@ private RpcSaslProto buildSaslNegotiateResponse() sentNegotiate = true; return negotiateMessage; } - + private SaslServer createSaslServer(AuthMethod authMethod) throws IOException, InterruptedException { final Map saslProps = saslPropsResolver.getServerProperties(addr, ingressPort); return new SaslRpcServer(authMethod).create(this, saslProps, secretManager); } - + /** * Try to set up the response to indicate that the client version * is incompatible with the server. This can contain special-case * code to speak enough of past IPC protocols to pass back * an exception to the caller. - * @param clientVersion the version the caller is using + * @param clientVersion the version the caller is using * @throws IOException */ private void setupBadVersionResponse(int clientVersion) throws IOException { String errMsg = "Server IPC version " + CURRENT_VERSION + " cannot communicate with client version " + clientVersion; ByteArrayOutputStream buffer = new ByteArrayOutputStream(); - + if (clientVersion >= 9) { // Versions >>9 understand the normal response RpcCall fakeCall = new RpcCall(this, -1); @@ -2716,7 +2717,7 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { sendResponse(fakeCall); } } - + private void setupHttpRequestOnIpcPortResponse() throws IOException { RpcCall fakeCall = new RpcCall(this, 0); fakeCall.setResponse(ByteBuffer.wrap( @@ -2727,7 +2728,7 @@ private void setupHttpRequestOnIpcPortResponse() throws IOException { /** Reads the connection context following the connection header * @throws RpcServerException - if the header cannot be * deserialized, or the user is not authorized - */ + */ private void processConnectionContext(RpcWritable.Buffer buffer) throws RpcServerException { // allow only one connection context during a session @@ -2747,7 +2748,7 @@ private void processConnectionContext(RpcWritable.Buffer buffer) // user is authenticated user.setAuthenticationMethod(authMethod); //Now we check if this is a proxy user case. If the protocol user is - //different from the 'user', it is a proxy user scenario. However, + //different from the 'user', it is a proxy user scenario. However, //this is not allowed if user authenticated with DIGEST. if ((protocolUser != null) && (!protocolUser.getUserName().equals(user.getUserName()))) { @@ -2775,14 +2776,14 @@ private void processConnectionContext(RpcWritable.Buffer buffer) connectionManager.incrUserConnections(user.getShortUserName()); } } - + /** * Process a wrapped RPC Request - unwrap the SASL packet and process - * each embedded RPC request + * each embedded RPC request * @param inBuf - SASL wrapped request of one or more RPCs * @throws IOException - SASL packet cannot be unwrapped * @throws InterruptedException - */ + */ private void unwrapPacketAndProcessRpcs(byte[] inBuf) throws IOException, InterruptedException { LOG.debug("Have read input token of size {} for processing by saslServer.unwrap()", @@ -2818,18 +2819,18 @@ private void unwrapPacketAndProcessRpcs(byte[] inBuf) } } } - + /** - * Process one RPC Request from buffer read from socket stream + * Process one RPC Request from buffer read from socket stream * - decode rpc in a rpc-Call * - handle out-of-band RPC requests such as the initial connectionContext * - A successfully decoded RpcCall will be deposited in RPC-Q and * its response will be sent later when the request is processed. - * + * * Prior to this call the connectionHeader ("hrpc...") has been handled and * if SASL then SASL has been established and the buf we are passed * has been unwrapped from SASL. - * + * * @param bb - contains the RPC request header and the rpc request * @throws IOException - internal error that should not be returned to * client, typically failure to respond to client @@ -2887,15 +2888,15 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } - if (header.getRpcOp() != + if (header.getRpcOp() != RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) { - String err = "IPC Server does not implement rpc header operation" + + String err = "IPC Server does not implement rpc header operation" + header.getRpcOp(); throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } // If we know the rpc kind, get its class so that we can deserialize - // (Note it would make more sense to have the handler deserialize but + // (Note it would make more sense to have the handler deserialize but // we continue with this original design. if (!header.hasRpcKind()) { String err = " IPC Server: No rpc kind in rpcRequestHeader"; @@ -2905,7 +2906,7 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) } /** - * Process an RPC Request + * Process an RPC Request * - the connection headers and context must have been already read. * - Based on the rpcKind, decode the rpcRequest. * - A successfully decoded RpcCall will be deposited in RPC-Q and @@ -2922,12 +2923,12 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) private void processRpcRequest(RpcRequestHeaderProto header, RpcWritable.Buffer buffer) throws RpcServerException, InterruptedException { - Class rpcRequestClass = + Class rpcRequestClass = getRpcRequestWrapper(header.getRpcKind()); if (rpcRequestClass == null) { - LOG.warn("Unknown rpc kind " + header.getRpcKind() + + LOG.warn("Unknown rpc kind " + header.getRpcKind() + " from client " + getHostAddress()); - final String err = "Unknown rpc kind in rpc header" + + final String err = "Unknown rpc kind in rpc header" + header.getRpcKind(); throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); @@ -2975,51 +2976,64 @@ private void processRpcRequest(RpcRequestHeaderProto header, .build(); } - RpcCall call = new RpcCall(this, header.getCallId(), - header.getRetryCount(), rpcRequest, - ProtoUtil.convert(header.getRpcKind()), - header.getClientId().toByteArray(), span, callerContext); - - // Save the priority level assignment by the scheduler - call.setPriorityLevel(callQueue.getPriorityLevel(call)); - call.markCallCoordinated(false); - if(alignmentContext != null && call.rpcRequest != null && - (call.rpcRequest instanceof ProtobufRpcEngine2.RpcProtobufRequest)) { - // if call.rpcRequest is not RpcProtobufRequest, will skip the following - // step and treat the call as uncoordinated. As currently only certain - // ClientProtocol methods request made through RPC protobuf needs to be - // coordinated. - String methodName; - String protoName; - ProtobufRpcEngine2.RpcProtobufRequest req = - (ProtobufRpcEngine2.RpcProtobufRequest) call.rpcRequest; - try { - methodName = req.getRequestHeader().getMethodName(); - protoName = req.getRequestHeader().getDeclaringClassProtocolName(); - if (alignmentContext.isCoordinatedCall(protoName, methodName)) { - call.markCallCoordinated(true); - long stateId; - stateId = alignmentContext.receiveRequestState( - header, getMaxIdleTime()); - call.setClientStateId(stateId); - if (header.hasRouterFederatedState()) { - call.setFederatedNamespaceState(header.getRouterFederatedState()); + // Set AuthorizationContext for this thread if present + boolean authzSet = false; + try { + if (header.hasAuthorizationHeader()) { + AuthorizationContext.setCurrentAuthorizationHeader(header.getAuthorizationHeader().toByteArray()); + authzSet = true; + } + + RpcCall call = new RpcCall(this, header.getCallId(), + header.getRetryCount(), rpcRequest, + ProtoUtil.convert(header.getRpcKind()), + header.getClientId().toByteArray(), span, callerContext); + + // Save the priority level assignment by the scheduler + call.setPriorityLevel(callQueue.getPriorityLevel(call)); + call.markCallCoordinated(false); + if (alignmentContext != null && call.rpcRequest != null && + (call.rpcRequest instanceof ProtobufRpcEngine2.RpcProtobufRequest)) { + // if call.rpcRequest is not RpcProtobufRequest, will skip the following + // step and treat the call as uncoordinated. As currently only certain + // ClientProtocol methods request made through RPC protobuf needs to be + // coordinated. + String methodName; + String protoName; + ProtobufRpcEngine2.RpcProtobufRequest req = + (ProtobufRpcEngine2.RpcProtobufRequest) call.rpcRequest; + try { + methodName = req.getRequestHeader().getMethodName(); + protoName = req.getRequestHeader().getDeclaringClassProtocolName(); + if (alignmentContext.isCoordinatedCall(protoName, methodName)) { + call.markCallCoordinated(true); + long stateId; + stateId = alignmentContext.receiveRequestState( + header, getMaxIdleTime()); + call.setClientStateId(stateId); + if (header.hasRouterFederatedState()) { + call.setFederatedNamespaceState(header.getRouterFederatedState()); + } } + } catch (IOException ioe) { + throw new RpcServerException("Processing RPC request caught ", ioe); } - } catch (IOException ioe) { - throw new RpcServerException("Processing RPC request caught ", ioe); } - } - try { - internalQueueCall(call); - } catch (RpcServerException rse) { - throw rse; - } catch (IOException ioe) { - throw new FatalRpcServerException( - RpcErrorCodeProto.ERROR_RPC_SERVER, ioe); + try { + internalQueueCall(call); + } catch (RpcServerException rse) { + throw rse; + } catch (IOException ioe) { + throw new FatalRpcServerException( + RpcErrorCodeProto.ERROR_RPC_SERVER, ioe); + } + incRpcCount(); // Increment the rpc count + } finally { + if (authzSet) { + AuthorizationContext.clear(); + } } - incRpcCount(); // Increment the rpc count } /** @@ -3029,7 +3043,7 @@ private void processRpcRequest(RpcRequestHeaderProto header, * @param buffer - stream to request payload * @throws RpcServerException - setup failed due to SASL * negotiation failure, premature or invalid connection context, - * or other state errors. This exception needs to be sent to the + * or other state errors. This exception needs to be sent to the * client. * @throws IOException - failed to send a response back to the client * @throws InterruptedException @@ -3062,7 +3076,7 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Unknown out of band call #" + callId); } - } + } /** * Authorize proxy users to access this server @@ -3090,9 +3104,9 @@ private void authorizeConnection() throws RpcServerException { RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae); } } - + /** - * Decode the a protobuf from the given input stream + * Decode the a protobuf from the given input stream * @return Message - decoded protobuf * @throws RpcServerException - deserialization failed */ @@ -3313,33 +3327,33 @@ void logException(Logger logger, Throwable e, Call call) { logger.info(logMsg, e); } } - + protected Server(String bindAddress, int port, - Class paramClass, int handlerCount, + Class paramClass, int handlerCount, Configuration conf) - throws IOException + throws IOException { this(bindAddress, port, paramClass, handlerCount, -1, -1, conf, Integer .toString(port), null, null); } - + protected Server(String bindAddress, int port, Class rpcRequestClass, int handlerCount, int numReaders, int queueSizePerHandler, Configuration conf, String serverName, SecretManager secretManager) throws IOException { - this(bindAddress, port, rpcRequestClass, handlerCount, numReaders, + this(bindAddress, port, rpcRequestClass, handlerCount, numReaders, queueSizePerHandler, conf, serverName, secretManager, null); } - - /** + + /** * Constructs a server listening on the named port and address. Parameters passed must * be of the named class. The handlerCount determines * the number of handler threads that will be used to process calls. * If queueSizePerHandler or numReaders are not -1 they will be used instead of parameters * from configuration. Otherwise the configuration will be picked up. - * - * If rpcRequestClass is null then the rpcRequestClass must have been + * + * If rpcRequestClass is null then the rpcRequestClass must have been * registered via {@link #registerProtocolEngine(RPC.RpcKind, * Class, RPC.RpcInvoker)} * This parameter has been retained for compatibility with existing tests @@ -3368,7 +3382,7 @@ protected Server(String bindAddress, int port, this.conf = conf; this.portRangeConfig = portRangeConfig; this.port = port; - this.rpcRequestClass = rpcRequestClass; + this.rpcRequestClass = rpcRequestClass; this.handlerCount = handlerCount; this.socketSendBufferSize = 0; this.serverName = serverName; @@ -3380,7 +3394,7 @@ protected Server(String bindAddress, int port, } else { this.maxQueueSize = handlerCount * conf.getInt( CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_KEY, - CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_DEFAULT); + CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_DEFAULT); } this.maxRespSize = conf.getInt( CommonConfigurationKeys.IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY, @@ -3405,14 +3419,14 @@ protected Server(String bindAddress, int port, maxQueueSize, prefix, conf); this.secretManager = (SecretManager) secretManager; - this.authorize = - conf.getBoolean(CommonConfigurationKeys.HADOOP_SECURITY_AUTHORIZATION, + this.authorize = + conf.getBoolean(CommonConfigurationKeys.HADOOP_SECURITY_AUTHORIZATION, false); // configure supported authentications this.enabledAuthMethods = getAuthMethods(secretManager, conf); this.negotiateResponse = buildNegotiateResponse(enabledAuthMethods); - + // Start the listener here and let it bind to the port listener = new Listener(port); // set the server port to the default listener port. @@ -3438,12 +3452,12 @@ protected Server(String bindAddress, int port, // Create the responder here responder = new Responder(); - + if (secretManager != null || UserGroupInformation.isSecurityEnabled()) { SaslRpcServer.init(conf); saslPropsResolver = SaslPropertiesResolver.getInstance(conf); } - + this.exceptionsHandler.addTerseLoggingExceptions(StandbyException.class); this.exceptionsHandler.addTerseLoggingExceptions( HealthCheckFailedException.class); @@ -3504,7 +3518,7 @@ private RpcSaslProto buildNegotiateResponse(List authMethods) } else { negotiateBuilder.setState(SaslState.NEGOTIATE); for (AuthMethod authMethod : authMethods) { - SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); + SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); SaslAuth.Builder builder = negotiateBuilder.addAuthsBuilder() .setMethod(authMethod.toString()) .setMechanism(saslRpcServer.mechanism); @@ -3525,31 +3539,31 @@ private RpcSaslProto buildNegotiateResponse(List authMethods) private List getAuthMethods(SecretManager secretManager, Configuration conf) { AuthenticationMethod confAuthenticationMethod = - SecurityUtil.getAuthenticationMethod(conf); + SecurityUtil.getAuthenticationMethod(conf); List authMethods = new ArrayList(); if (confAuthenticationMethod == AuthenticationMethod.TOKEN) { if (secretManager == null) { throw new IllegalArgumentException(AuthenticationMethod.TOKEN + " authentication requires a secret manager"); - } + } } else if (secretManager != null) { LOG.debug("{} authentication enabled for secret manager", AuthenticationMethod.TOKEN); // most preferred, go to the front of the line! authMethods.add(AuthenticationMethod.TOKEN.getAuthMethod()); } - authMethods.add(confAuthenticationMethod.getAuthMethod()); - + authMethods.add(confAuthenticationMethod.getAuthMethod()); + LOG.debug("Server accepts auth methods:{}", authMethods); return authMethods; } - + private void closeConnection(Connection connection) { connectionManager.close(connection); } /** * Setup response for the IPC Call. - * + * * @param call {@link Call} to which we are setting up the response * @param status of the IPC call * @param rv return value for the IPC Call, if the call was successful @@ -3665,11 +3679,11 @@ private static int getDelimitedLength(Message message) { } /** - * Setup response for the IPC Call on Fatal Error from a + * Setup response for the IPC Call on Fatal Error from a * client that is using old version of Hadoop. * The response is serialized using the previous protocol's response * layout. - * + * * @param response buffer to serialize the response into * @param call {@link Call} to which we are setting up the response * @param rv return value for the IPC Call, if the call was successful @@ -3677,9 +3691,9 @@ private static int getDelimitedLength(Message message) { * @param error error message, if the call failed * @throws IOException */ - private void setupResponseOldVersionFatal(ByteArrayOutputStream response, + private void setupResponseOldVersionFatal(ByteArrayOutputStream response, RpcCall call, - Writable rv, String errorClass, String error) + Writable rv, String errorClass, String error) throws IOException { final int OLD_VERSION_FATAL_STATUS = -1; response.reset(); @@ -3712,11 +3726,11 @@ private void wrapWithSasl(RpcCall call) throws IOException { setupResponse(call, saslHeader, RpcWritable.wrap(saslMessage)); } } - + Configuration getConf() { return conf; } - + /** * Sets the socket buffer size used for responding to RPCs. * @param size input size. @@ -3738,7 +3752,7 @@ public synchronized void start() { } handlers = new Handler[handlerCount]; - + for (int i = 0; i < handlerCount; i++) { handlers[i] = new Handler(i); handlers[i].start(); @@ -3821,9 +3835,9 @@ public synchronized Set getAuxiliaryListenerAddresses() { } return allAddrs; } - - /** - * Called for each call. + + /** + * Called for each call. * @deprecated Use {@link #call(RPC.RpcKind, String, * Writable, long)} instead * @param param input param. @@ -3835,7 +3849,7 @@ public synchronized Set getAuxiliaryListenerAddresses() { public Writable call(Writable param, long receiveTime) throws Exception { return call(RPC.RpcKind.RPC_BUILTIN, null, param, receiveTime); } - + /** * Called for each call. * @param rpcKind input rpcKind. @@ -3847,10 +3861,10 @@ public Writable call(Writable param, long receiveTime) throws Exception { */ public abstract Writable call(RPC.RpcKind rpcKind, String protocol, Writable param, long receiveTime) throws Exception; - + /** * Authorize the incoming client connection. - * + * * @param user client user * @param protocolName - the protocol * @param addr InetAddress of incoming connection @@ -3866,13 +3880,13 @@ private void authorize(UserGroupInformation user, String protocolName, try { protocol = getProtocolClass(protocolName, getConf()); } catch (ClassNotFoundException cfne) { - throw new AuthorizationException("Unknown protocol: " + + throw new AuthorizationException("Unknown protocol: " + protocolName); } serviceAuthorizationManager.authorize(user, protocol, getConf(), addr); } } - + /** * Get the port on which the IPC Server is listening for incoming connections. * This could be an ephemeral port too, in which case we return the real @@ -3882,7 +3896,7 @@ private void authorize(UserGroupInformation user, String protocolName, public int getPort() { return port; } - + /** * The number of open RPC conections * @return the number of open rpc connections @@ -3957,25 +3971,25 @@ public int getNumReaders() { } /** - * When the read or write buffer size is larger than this limit, i/o will be + * When the read or write buffer size is larger than this limit, i/o will be * done in chunks of this size. Most RPC requests and responses would be * be smaller. */ private static int NIO_BUFFER_LIMIT = 8*1024; //should not be more than 64KB. - + /** * This is a wrapper around {@link WritableByteChannel#write(ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of + * If the amount of data is large, it writes to channel in smaller chunks. + * This is to avoid jdk from creating many direct buffers as the size of * buffer increases. This also minimizes extra copies in NIO layer - * as a result of multiple write operations required to write a large - * buffer. + * as a result of multiple write operations required to write a large + * buffer. * * @see WritableByteChannel#write(ByteBuffer) */ - private int channelWrite(WritableByteChannel channel, + private int channelWrite(WritableByteChannel channel, ByteBuffer buffer) throws IOException { - + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? channel.write(buffer) : channelIO(null, channel, buffer); if (count > 0) { @@ -3983,19 +3997,19 @@ private int channelWrite(WritableByteChannel channel, } return count; } - - + + /** * This is a wrapper around {@link ReadableByteChannel#read(ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of + * If the amount of data is large, it writes to channel in smaller chunks. + * This is to avoid jdk from creating many direct buffers as the size of * ByteBuffer increases. There should not be any performance degredation. - * + * * @see ReadableByteChannel#read(ByteBuffer) */ - private int channelRead(ReadableByteChannel channel, + private int channelRead(ReadableByteChannel channel, ByteBuffer buffer) throws IOException { - + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? channel.read(buffer) : channelIO(channel, null, buffer); if (count > 0) { @@ -4003,43 +4017,43 @@ private int channelRead(ReadableByteChannel channel, } return count; } - + /** * Helper for {@link #channelRead(ReadableByteChannel, ByteBuffer)} * and {@link #channelWrite(WritableByteChannel, ByteBuffer)}. Only * one of readCh or writeCh should be non-null. - * + * * @see #channelRead(ReadableByteChannel, ByteBuffer) * @see #channelWrite(WritableByteChannel, ByteBuffer) */ - private static int channelIO(ReadableByteChannel readCh, + private static int channelIO(ReadableByteChannel readCh, WritableByteChannel writeCh, ByteBuffer buf) throws IOException { - + int originalLimit = buf.limit(); int initialRemaining = buf.remaining(); int ret = 0; - + while (buf.remaining() > 0) { try { int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); buf.limit(buf.position() + ioSize); - - ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); - + + ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); + if (ret < ioSize) { break; } } finally { - buf.limit(originalLimit); + buf.limit(originalLimit); } } - int nBytes = initialRemaining - buf.remaining(); + int nBytes = initialRemaining - buf.remaining(); return (nBytes > 0) ? nBytes : ret; } - + private class ConnectionManager { final private AtomicInteger count = new AtomicInteger(); final private AtomicLong droppedConnections = new AtomicLong(); @@ -4054,7 +4068,7 @@ private class ConnectionManager { final private int maxIdleTime; final private int maxIdleToClose; final private int maxConnections; - + ConnectionManager() { this.idleScanTimer = new Timer( "IPC Server idle connection scanner for port " + getPort(), true); @@ -4088,7 +4102,7 @@ private boolean add(Connection connection) { } return added; } - + private boolean remove(Connection connection) { boolean removed = connections.remove(connection); if (removed) { @@ -4159,7 +4173,7 @@ Connection register(SocketChannel channel, int ingressPort, connection, size(), callQueue.size()); return connection; } - + boolean close(Connection connection) { boolean exists = remove(connection); if (exists) { @@ -4177,7 +4191,7 @@ boolean close(Connection connection) { } return exists; } - + // synch'ed to avoid explicit invocation upon OOM from colliding with // timer task firing synchronized void closeIdle(boolean scanAll) { @@ -4200,7 +4214,7 @@ synchronized void closeIdle(boolean scanAll) { } } } - + void closeAll() { // use a copy of the connections to be absolutely sure the concurrent // iterator doesn't miss a connection @@ -4208,15 +4222,15 @@ void closeAll() { close(connection); } } - + void startIdleScan() { scheduleIdleScanTask(); } - + void stopIdleScan() { idleScanTimer.cancel(); } - + private void scheduleIdleScanTask() { if (!running) { return; diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java new file mode 100644 index 0000000000000..6026167c9968b --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java @@ -0,0 +1,22 @@ +package org.apache.hadoop.security; + +/** + * Utility for managing a thread-local authorization header for RPC calls. + */ +public final class AuthorizationContext { + private static final ThreadLocal AUTH_HEADER = new ThreadLocal<>(); + + private AuthorizationContext() {} + + public static void setCurrentAuthorizationHeader(byte[] header) { + AUTH_HEADER.set(header); + } + + public static byte[] getCurrentAuthorizationHeader() { + return AUTH_HEADER.get(); + } + + public static void clear() { + AUTH_HEADER.remove(); + } +} \ No newline at end of file diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java index 883c19c5e7750..307be15db6f34 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/ProtoUtil.java @@ -32,6 +32,7 @@ import org.apache.hadoop.tracing.Span; import org.apache.hadoop.tracing.Tracer; import org.apache.hadoop.tracing.TraceUtils; +import org.apache.hadoop.security.AuthorizationContext; import org.apache.hadoop.thirdparty.protobuf.ByteString; @@ -203,6 +204,12 @@ public static RpcRequestHeaderProto makeRpcRequestHeader(RPC.RpcKind rpcKind, result.setCallerContext(contextBuilder); } + // Add authorization header if present + byte[] authzHeader = AuthorizationContext.getCurrentAuthorizationHeader(); + if (authzHeader != null) { + result.setAuthorizationHeader(ByteString.copyFrom(authzHeader)); + } + // Add alignment context if it is not null if (alignmentContext != null) { alignmentContext.updateRequestState(result); diff --git a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto index d9becf722e982..0eba4eda9e5a5 100644 --- a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto +++ b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto @@ -90,6 +90,8 @@ message RpcRequestHeaderProto { // the header for the RpcRequest optional sint32 retryCount = 5 [default = -1]; optional RPCTraceInfoProto traceInfo = 6; // tracing info optional RPCCallerContextProto callerContext = 7; // call context + // Authorization header for passing opaque credentials or tokens + optional bytes authorizationHeader = 10; optional int64 stateId = 8; // The last seen Global State ID // Alignment context info for use with routers. // The client should not interpret these bytes, but only forward bytes diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java new file mode 100644 index 0000000000000..5d6a649b15e94 --- /dev/null +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java @@ -0,0 +1,52 @@ +package org.apache.hadoop.security; + +import org.junit.Assert; +import org.junit.Test; + +public class TestAuthorizationContext { + + @Test + public void testSetAndGetAuthorizationHeader() { + byte[] header = "my-auth-header".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(header); + Assert.assertArrayEquals(header, AuthorizationContext.getCurrentAuthorizationHeader()); + AuthorizationContext.clear(); + } + + @Test + public void testClearAuthorizationHeader() { + byte[] header = "clear-me".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(header); + AuthorizationContext.clear(); + Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + } + + @Test + public void testThreadLocalIsolation() throws Exception { + byte[] mainHeader = "main-thread".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(mainHeader); + Thread t = new Thread(() -> { + Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + byte[] threadHeader = "other-thread".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(threadHeader); + Assert.assertArrayEquals(threadHeader, AuthorizationContext.getCurrentAuthorizationHeader()); + AuthorizationContext.clear(); + Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + }); + t.start(); + t.join(); + // Main thread should still have its header + Assert.assertArrayEquals(mainHeader, AuthorizationContext.getCurrentAuthorizationHeader()); + AuthorizationContext.clear(); + } + + @Test + public void testNullAndEmptyHeader() { + AuthorizationContext.setCurrentAuthorizationHeader(null); + Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + byte[] empty = new byte[0]; + AuthorizationContext.setCurrentAuthorizationHeader(empty); + Assert.assertArrayEquals(empty, AuthorizationContext.getCurrentAuthorizationHeader()); + AuthorizationContext.clear(); + } +} \ No newline at end of file diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java new file mode 100644 index 0000000000000..f5eb80673fac3 --- /dev/null +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java @@ -0,0 +1,64 @@ +package org.apache.hadoop.hdfs.server.namenode; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.hdfs.HdfsConfiguration; +import org.apache.hadoop.hdfs.MiniDFSCluster; +import org.apache.hadoop.security.AuthorizationContext; +import org.junit.Test; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_NAMENODE_AUDIT_LOGGERS_KEY; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNull; + +public class TestAuthorizationHeaderPropagation { + + public static class HeaderCapturingAuditLogger implements AuditLogger { + public static final List capturedHeaders = new ArrayList<>(); + @Override + public void initialize(Configuration conf) {} + @Override + public void logAuditEvent(boolean succeeded, String userName, InetAddress addr, + String cmd, String src, String dst, FileStatus stat) { + byte[] header = AuthorizationContext.getCurrentAuthorizationHeader(); + capturedHeaders.add(header == null ? null : Arrays.copyOf(header, header.length)); + } + } + + @Test + public void testAuthorizationHeaderPerRpc() throws Exception { + Configuration conf = new HdfsConfiguration(); + conf.set(DFS_NAMENODE_AUDIT_LOGGERS_KEY, HeaderCapturingAuditLogger.class.getName()); + MiniDFSCluster cluster = new MiniDFSCluster.Builder(conf).build(); + try { + cluster.waitClusterUp(); + HeaderCapturingAuditLogger.capturedHeaders.clear(); + FileSystem fs = cluster.getFileSystem(); + // First RPC with header1 + byte[] header1 = "header-one".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(header1); + fs.mkdirs(new Path("/authz1")); + AuthorizationContext.clear(); + // Second RPC with header2 + byte[] header2 = "header-two".getBytes(); + AuthorizationContext.setCurrentAuthorizationHeader(header2); + fs.mkdirs(new Path("/authz2")); + AuthorizationContext.clear(); + // Third RPC with no header + fs.mkdirs(new Path("/authz3")); + // Now assert + assertArrayEquals(header1, HeaderCapturingAuditLogger.capturedHeaders.get(0)); + assertArrayEquals(header2, HeaderCapturingAuditLogger.capturedHeaders.get(1)); + assertNull(HeaderCapturingAuditLogger.capturedHeaders.get(2)); + } finally { + cluster.shutdown(); + } + } +} \ No newline at end of file From a35593184ea07a202c63bc3b84557134964ff5ad Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Wed, 16 Jul 2025 11:23:39 -0400 Subject: [PATCH 2/6] Add license to new files --- .../hadoop/security/AuthorizationContext.java | 17 +++++++++++++++++ .../security/TestAuthorizationContext.java | 17 +++++++++++++++++ .../TestAuthorizationHeaderPropagation.java | 17 +++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java index 6026167c9968b..4b8793975c4f5 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/AuthorizationContext.java @@ -1,3 +1,20 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.hadoop.security; /** diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java index 5d6a649b15e94..2ccda42e98e69 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java @@ -1,3 +1,20 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.hadoop.security; import org.junit.Assert; diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java index f5eb80673fac3..9bdc4fefb959c 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java @@ -1,3 +1,20 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.hadoop.hdfs.server.namenode; import org.apache.hadoop.conf.Configuration; From 90cc57d2268ed67323ddf19315cf6c563df15b1f Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Wed, 16 Jul 2025 12:48:38 -0400 Subject: [PATCH 3/6] fix ordering of new proto field --- .../src/main/proto/RpcHeader.proto | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto index 0eba4eda9e5a5..71a75f11b9097 100644 --- a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto +++ b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto @@ -29,7 +29,7 @@ package hadoop.common; /** * This is the rpc request header. It is sent with every rpc call. - * + * * The format of RPC call is as follows: * +--------------------------------------------------------------+ * | Rpc length in bytes (4 bytes int) sum of next two parts | @@ -47,12 +47,12 @@ package hadoop.common; */ enum RpcKindProto { RPC_BUILTIN = 0; // Used for built in calls by tests - RPC_WRITABLE = 1; // Use WritableRpcEngine + RPC_WRITABLE = 1; // Use WritableRpcEngine RPC_PROTOCOL_BUFFER = 2; // Use ProtobufRpcEngine } - + /** * Used to pass through the information necessary to continue * a trace after an RPC is made. All we need is the traceid @@ -90,13 +90,13 @@ message RpcRequestHeaderProto { // the header for the RpcRequest optional sint32 retryCount = 5 [default = -1]; optional RPCTraceInfoProto traceInfo = 6; // tracing info optional RPCCallerContextProto callerContext = 7; // call context - // Authorization header for passing opaque credentials or tokens - optional bytes authorizationHeader = 10; optional int64 stateId = 8; // The last seen Global State ID // Alignment context info for use with routers. // The client should not interpret these bytes, but only forward bytes // received from RpcResponseHeaderProto.routerFederatedState. optional bytes routerFederatedState = 9; + // Authorization header for passing opaque credentials or tokens + optional bytes authorizationHeader = 10; } @@ -117,12 +117,12 @@ message RpcRequestHeaderProto { // the header for the RpcRequest * | The rpc response header contains the necessary info | * +------------------------------------------------------------------+ * - * Note that rpc response header is also used when connection setup fails. + * Note that rpc response header is also used when connection setup fails. * Ie the response looks like a rpc response with a fake callId. */ message RpcResponseHeaderProto { /** - * + * * RpcStastus - success or failure * The reponseHeader's errDetail, exceptionClassName and errMsg contains * further details on the error @@ -178,7 +178,7 @@ message RpcSaslProto { RESPONSE = 4; WRAP = 5; } - + message SaslAuth { required string method = 1; required string mechanism = 2; @@ -187,7 +187,7 @@ message RpcSaslProto { optional bytes challenge = 5; } - optional uint32 version = 1; + optional uint32 version = 1; required SaslState state = 2; optional bytes token = 3; repeated SaslAuth auths = 4; From 49a8b721a3a45a6a16463792ea424ff3876600c7 Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Wed, 16 Jul 2025 14:45:24 -0400 Subject: [PATCH 4/6] Pass auth header to rpc call and fix unit tests to be run --- .../java/org/apache/hadoop/ipc/Server.java | 22 ++++++++++++++++--- .../security/TestAuthorizationContext.java | 20 ++++++++--------- .../TestAuthorizationHeaderPropagation.java | 6 ++--- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index f76212e3c76b8..c37210bf164f6 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -1005,6 +1005,7 @@ public static class Call implements Schedulable, final byte[] clientId; private final Span span; // the trace span on the server side private final CallerContext callerContext; // the call context + private final byte[] authHeader; // the auth header private boolean deferredResponse = false; private int priorityLevel; // the priority level assigned by scheduler, 0 by default @@ -1036,6 +1037,11 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2, Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId, Span span, CallerContext callerContext) { + this(id, retryCount, kind, clientId, span, callerContext, null); + } + + Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId, + Span span, CallerContext callerContext, byte[] authHeader) { this.callId = id; this.retryCount = retryCount; this.timestampNanos = Time.monotonicNowNanos(); @@ -1044,6 +1050,7 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2, this.clientId = clientId; this.span = span; this.callerContext = callerContext; + this.authHeader = authHeader; this.clientStateId = Long.MIN_VALUE; this.isCallCoordinated = false; } @@ -1244,7 +1251,14 @@ private class RpcCall extends Call { RpcCall(Connection connection, int id, int retryCount, Writable param, RPC.RpcKind kind, byte[] clientId, Span span, CallerContext context) { - super(id, retryCount, kind, clientId, span, context); + this(connection, id, retryCount, param, kind, clientId, + span, context, new byte[0]); + } + + RpcCall(Connection connection, int id, int retryCount, + Writable param, RPC.RpcKind kind, byte[] clientId, + Span span, CallerContext context, byte[] authHeader) { + super(id, retryCount, kind, clientId, span, context, authHeader); this.connection = connection; this.rpcRequest = param; } @@ -2977,17 +2991,18 @@ private void processRpcRequest(RpcRequestHeaderProto header, } // Set AuthorizationContext for this thread if present + byte[] authHeader = null; boolean authzSet = false; try { if (header.hasAuthorizationHeader()) { - AuthorizationContext.setCurrentAuthorizationHeader(header.getAuthorizationHeader().toByteArray()); + authHeader = header.getAuthorizationHeader().toByteArray(); authzSet = true; } RpcCall call = new RpcCall(this, header.getCallId(), header.getRetryCount(), rpcRequest, ProtoUtil.convert(header.getRpcKind()), - header.getClientId().toByteArray(), span, callerContext); + header.getClientId().toByteArray(), span, callerContext, authHeader); // Save the priority level assignment by the scheduler call.setPriorityLevel(callQueue.getPriorityLevel(call)); @@ -3259,6 +3274,7 @@ public void run() { } // always update the current call context CallerContext.setCurrent(call.callerContext); + AuthorizationContext.setCurrentAuthorizationHeader(call.authHeader); UserGroupInformation remoteUser = call.getRemoteUser(); connDropped = !call.isOpen(); if (remoteUser != null) { diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java index 2ccda42e98e69..fe6bc4f58de93 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/security/TestAuthorizationContext.java @@ -17,8 +17,8 @@ */ package org.apache.hadoop.security; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; public class TestAuthorizationContext { @@ -26,7 +26,7 @@ public class TestAuthorizationContext { public void testSetAndGetAuthorizationHeader() { byte[] header = "my-auth-header".getBytes(); AuthorizationContext.setCurrentAuthorizationHeader(header); - Assert.assertArrayEquals(header, AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertArrayEquals(header, AuthorizationContext.getCurrentAuthorizationHeader()); AuthorizationContext.clear(); } @@ -35,7 +35,7 @@ public void testClearAuthorizationHeader() { byte[] header = "clear-me".getBytes(); AuthorizationContext.setCurrentAuthorizationHeader(header); AuthorizationContext.clear(); - Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); } @Test @@ -43,27 +43,27 @@ public void testThreadLocalIsolation() throws Exception { byte[] mainHeader = "main-thread".getBytes(); AuthorizationContext.setCurrentAuthorizationHeader(mainHeader); Thread t = new Thread(() -> { - Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); byte[] threadHeader = "other-thread".getBytes(); AuthorizationContext.setCurrentAuthorizationHeader(threadHeader); - Assert.assertArrayEquals(threadHeader, AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertArrayEquals(threadHeader, AuthorizationContext.getCurrentAuthorizationHeader()); AuthorizationContext.clear(); - Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); }); t.start(); t.join(); // Main thread should still have its header - Assert.assertArrayEquals(mainHeader, AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertArrayEquals(mainHeader, AuthorizationContext.getCurrentAuthorizationHeader()); AuthorizationContext.clear(); } @Test public void testNullAndEmptyHeader() { AuthorizationContext.setCurrentAuthorizationHeader(null); - Assert.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader()); byte[] empty = new byte[0]; AuthorizationContext.setCurrentAuthorizationHeader(empty); - Assert.assertArrayEquals(empty, AuthorizationContext.getCurrentAuthorizationHeader()); + Assertions.assertArrayEquals(empty, AuthorizationContext.getCurrentAuthorizationHeader()); AuthorizationContext.clear(); } } \ No newline at end of file diff --git a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java index 9bdc4fefb959c..351c1f814f885 100644 --- a/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java +++ b/hadoop-hdfs-project/hadoop-hdfs/src/test/java/org/apache/hadoop/hdfs/server/namenode/TestAuthorizationHeaderPropagation.java @@ -24,7 +24,7 @@ import org.apache.hadoop.hdfs.HdfsConfiguration; import org.apache.hadoop.hdfs.MiniDFSCluster; import org.apache.hadoop.security.AuthorizationContext; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.net.InetAddress; import java.util.ArrayList; @@ -32,8 +32,8 @@ import java.util.List; import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_NAMENODE_AUDIT_LOGGERS_KEY; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; public class TestAuthorizationHeaderPropagation { From 8d304ff9472fdc637d620af518797c60a832ab85 Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Mon, 21 Jul 2025 20:33:23 -0400 Subject: [PATCH 5/6] remove whitespace --- .../java/org/apache/hadoop/ipc/Server.java | 386 +++++++++--------- .../src/main/proto/RpcHeader.proto | 14 +- 2 files changed, 200 insertions(+), 200 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index c37210bf164f6..37a1e8e499240 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -144,7 +144,7 @@ /** An abstract IPC service. IPC calls take a single {@link Writable} as a * parameter, and return a {@link Writable} as their value. A service runs on * a port and is defined by a parameter class and a value class. - * + * * @see Client */ @Public @@ -209,7 +209,7 @@ static class ExceptionsHandler { /** * Add exception classes for which server won't log stack traces. * Optimized for infrequent invocation. - * @param exceptionClass exception classes + * @param exceptionClass exception classes */ void addTerseLoggingExceptions(Class... exceptionClass) { terseExceptions.addAll(Arrays @@ -240,14 +240,14 @@ boolean isSuppressedLog(Class t) { } - + /** * If the user accidentally sends an HTTP GET to an IPC port, we detect this * and send back a nicer response. */ private static final ByteBuffer HTTP_GET_BYTES = ByteBuffer.wrap( "GET ".getBytes(StandardCharsets.UTF_8)); - + /** * An HTTP response to send back if we detect an HTTP request to our IPC * port. @@ -262,7 +262,7 @@ boolean isSuppressedLog(Class t) { * Initial and max size of response buffer */ static int INITIAL_RESP_BUF_SIZE = 10240; - + static class RpcKindMapValue { final Class rpcRequestWrapperClass; final RpcInvoker rpcInvoker; @@ -271,42 +271,42 @@ static class RpcKindMapValue { RpcInvoker rpcInvoker) { this.rpcInvoker = rpcInvoker; this.rpcRequestWrapperClass = rpcRequestWrapperClass; - } + } } static Map rpcKindMap = new HashMap<>(4); - - + + /** * Register a RPC kind and the class to deserialize the rpc request. - * + * * Called by static initializers of rpcKind Engines * @param rpcKind - input rpcKind. * @param rpcRequestWrapperClass - this class is used to deserialze the * the rpc request. * @param rpcInvoker - use to process the calls on SS. */ - - public static void registerProtocolEngine(RPC.RpcKind rpcKind, + + public static void registerProtocolEngine(RPC.RpcKind rpcKind, Class rpcRequestWrapperClass, RpcInvoker rpcInvoker) { - RpcKindMapValue old = + RpcKindMapValue old = rpcKindMap.put(rpcKind, new RpcKindMapValue(rpcRequestWrapperClass, rpcInvoker)); if (old != null) { rpcKindMap.put(rpcKind, old); throw new IllegalArgumentException("ReRegistration of rpcKind: " + - rpcKind); + rpcKind); } LOG.debug("rpcKind={}, rpcRequestWrapperClass={}, rpcInvoker={}.", rpcKind, rpcRequestWrapperClass, rpcInvoker); } - + public Class getRpcRequestWrapper( RpcKindProto rpcKind) { if (rpcRequestClass != null) return rpcRequestClass; RpcKindMapValue val = rpcKindMap.get(ProtoUtil.convert(rpcKind)); - return (val == null) ? null : val.rpcRequestWrapperClass; + return (val == null) ? null : val.rpcRequestWrapperClass; } protected RpcInvoker getServerRpcInvoker(RPC.RpcKind rpcKind) { @@ -315,22 +315,22 @@ protected RpcInvoker getServerRpcInvoker(RPC.RpcKind rpcKind) { public static RpcInvoker getRpcInvoker(RPC.RpcKind rpcKind) { RpcKindMapValue val = rpcKindMap.get(rpcKind); - return (val == null) ? null : val.rpcInvoker; + return (val == null) ? null : val.rpcInvoker; } - + public static final Logger LOG = LoggerFactory.getLogger(Server.class); public static final Logger AUDITLOG = LoggerFactory.getLogger("SecurityLogger."+Server.class.getName()); private static final String AUTH_FAILED_FOR = "Auth failed for "; private static final String AUTH_SUCCESSFUL_FOR = "Auth successful for "; - + private static final ThreadLocal SERVER = new ThreadLocal(); - private static final Map> PROTOCOL_CACHE = + private static final Map> PROTOCOL_CACHE = new ConcurrentHashMap>(); - - static Class getProtocolClass(String protocolName, Configuration conf) + + static Class getProtocolClass(String protocolName, Configuration conf) throws ClassNotFoundException { Class protocol = PROTOCOL_CACHE.get(protocolName); if (protocol == null) { @@ -339,7 +339,7 @@ static Class getProtocolClass(String protocolName, Configuration conf) } return protocol; } - + /** @return Returns the server instance called under or null. May be called under * {@link #call(Writable, long)} implementations, and under {@link Writable} * methods of paramters and return values. Permits applications to access @@ -347,7 +347,7 @@ static Class getProtocolClass(String protocolName, Configuration conf) public static Server get() { return SERVER.get(); } - + /** This is set to Call object before Handler invokes an RPC and reset * after the call returns. */ @@ -363,14 +363,14 @@ public static ThreadLocal getCurCall() { * Returns the currently active RPC call's sequential ID number. A negative * call ID indicates an invalid value, such as if there is no currently active * RPC call. - * + * * @return int sequential ID number of currently active RPC call */ public static int getCallId() { Call call = CurCall.get(); return call != null ? call.callId : RpcConstants.INVALID_CALL_ID; } - + /** * @return The current active RPC call's retry count. -1 indicates the retry * cache is not supported in the client side. @@ -709,7 +709,7 @@ void updateDeferredMetrics(Call call, String name) { } /** - * A convenience method to bind to a given address and report + * A convenience method to bind to a given address and report * better exceptions if the address is not a valid host. * @param socket the socket to bind * @param address the address to bind to @@ -718,12 +718,12 @@ void updateDeferredMetrics(Call call, String name) { * @throws UnknownHostException if the address isn't a valid host name * @throws IOException other random errors from bind */ - public static void bind(ServerSocket socket, InetSocketAddress address, + public static void bind(ServerSocket socket, InetSocketAddress address, int backlog) throws IOException { bind(socket, address, backlog, null, null); } - public static void bind(ServerSocket socket, InetSocketAddress address, + public static void bind(ServerSocket socket, InetSocketAddress address, int backlog, Configuration conf, String rangeConf) throws IOException { try { IntegerRanges range = null; @@ -783,7 +783,7 @@ public RpcMetrics getRpcMetrics() { public RpcDetailedMetrics getRpcDetailedMetrics() { return rpcDetailedMetrics; } - + @VisibleForTesting Iterable getHandlers() { return Arrays.asList(handlers); @@ -1472,7 +1472,7 @@ public String toString() { /** Listens on the socket. Creates jobs for the handler threads*/ private class Listener extends Thread { - + private ServerSocketChannel acceptChannel = null; //the accept channel private Selector selector = null; //the selector that we use for the server private Reader[] readers = null; @@ -1519,7 +1519,7 @@ private class Listener extends Thread { void setIsAuxiliary() { this.isOnAuxiliaryPort = true; } - + private class Reader extends Thread { final private BlockingQueue pendingConnections; private final Selector readSelector; @@ -1531,7 +1531,7 @@ private class Reader extends Thread { new LinkedBlockingQueue(readerPendingConnectionQueue); this.readSelector = Selector.open(); } - + @Override public void run() { LOG.info("Starting " + Thread.currentThread().getName()); @@ -1635,7 +1635,7 @@ public void run() { } } catch (OutOfMemoryError e) { // we can run out of memory if we have too many threads - // log the event and sleep for a minute and give + // log the event and sleep for a minute and give // some thread(s) a chance to finish LOG.warn("Out of Memory in server select", e); closeCurrentConnection(key, e); @@ -1655,7 +1655,7 @@ public void run() { selector= null; acceptChannel= null; - + // close all connections connectionManager.stopIdleScan(); connectionManager.closeAll(); @@ -1675,7 +1675,7 @@ private void closeCurrentConnection(SelectionKey key, Throwable e) { InetSocketAddress getAddress() { return (InetSocketAddress)acceptChannel.socket().getLocalSocketAddress(); } - + void doAccept(SelectionKey key) throws InterruptedException, IOException, OutOfMemoryError { ServerSocketChannel server = (ServerSocketChannel) key.channel(); SocketChannel channel; @@ -1684,7 +1684,7 @@ void doAccept(SelectionKey key) throws InterruptedException, IOException, OutOf channel.configureBlocking(false); channel.socket().setTcpNoDelay(tcpNoDelay); channel.socket().setKeepAlive(true); - + Reader reader = getReader(); Connection c = connectionManager.register(channel, this.listenPort, this.isOnAuxiliaryPort); @@ -1705,10 +1705,10 @@ void doRead(SelectionKey key) throws InterruptedException { int count; Connection c = (Connection)key.attachment(); if (c == null) { - return; + return; } c.setLastContact(Time.now()); - + try { count = c.readAndProcess(); } catch (InterruptedException ieo) { @@ -1731,7 +1731,7 @@ void doRead(SelectionKey key) throws InterruptedException { else { c.setLastContact(Time.now()); } - } + } synchronized void doStop() { if (selector != null) { @@ -1749,7 +1749,7 @@ synchronized void doStop() { r.shutdown(); } } - + synchronized Selector getSelector() { return selector; } // The method that will return the next reader to work with // Simplistic implementation of round robin for now @@ -1786,7 +1786,7 @@ public void run() { } } } - + private void doRunLoop() { long lastPurgeTimeNanos = 0; // last check for old calls. @@ -1827,7 +1827,7 @@ private void doRunLoop() { // LOG.debug("Checking for old call responses."); ArrayList calls; - + // get the list of channels from list of keys. synchronized (writeSelector.keys()) { calls = new ArrayList(writeSelector.keys().size()); @@ -1835,7 +1835,7 @@ private void doRunLoop() { while (iter.hasNext()) { SelectionKey key = iter.next(); RpcCall call = (RpcCall)key.attachment(); - if (call != null && key.channel() == call.connection.channel) { + if (call != null && key.channel() == call.connection.channel) { calls.add(call); } } @@ -1884,7 +1884,7 @@ private void doAsyncWrite(SelectionKey key) throws IOException { } // - // Remove calls that have been pending in the responseQueue + // Remove calls that have been pending in the responseQueue // for a long time. // private void doPurge(RpcCall call, long now) { @@ -1947,18 +1947,18 @@ private boolean processResponse(LinkedList responseQueue, Thread.currentThread().getName(), call, numBytes); } else { // - // If we were unable to write the entire response out, then - // insert in Selector queue. + // If we were unable to write the entire response out, then + // insert in Selector queue. // call.connection.responseQueue.addFirst(call); - + if (inHandler) { // set the serve time when the response has to be sent later call.responseTimestampNanos = Time.monotonicNowNanos(); - + incPending(); try { - // Wakeup the thread blocked on select, only then can the call + // Wakeup the thread blocked on select, only then can the call // to channel.register() complete. writeSelector.wakeup(); channel.register(writeSelector, SelectionKey.OP_WRITE, call); @@ -2021,12 +2021,12 @@ private synchronized void waitPending() throws InterruptedException { public enum AuthProtocol { NONE(0), SASL(-33); - + public final int callId; AuthProtocol(int callId) { this.callId = callId; } - + static AuthProtocol valueOf(int callId) { for (AuthProtocol authType : AuthProtocol.values()) { if (authType.callId == callId) { @@ -2036,12 +2036,12 @@ static AuthProtocol valueOf(int callId) { return null; } }; - + /** * Wrapper for RPC IOExceptions to be returned to the client. Used to * let exceptions bubble up to top of processOneRpc where the correct * callId can be associated with the response. Also used to prevent - * unnecessary stack trace logging if it's not an internal server error. + * unnecessary stack trace logging if it's not an internal server error. */ @SuppressWarnings("serial") private static class FatalRpcServerException extends RpcServerException { @@ -2083,7 +2083,7 @@ public class Connection { private int dataLength; private Socket socket; - // Cache the remote host & port info so that even if the socket is + // Cache the remote host & port info so that even if the socket is // disconnected, we can say where it used to connect to. /** @@ -2123,13 +2123,13 @@ public class Connection { private boolean sentNegotiate = false; private boolean useWrap = false; - + public Connection(SocketChannel channel, long lastContact, int ingressPort, boolean isOnAuxiliaryPort) { this.channel = channel; this.lastContact = lastContact; this.data = null; - + // the buffer is initialized to read the "hrpc" and after that to read // the length of the Rpc-packet (i.e 4 bytes) this.dataLengthBuffer = ByteBuffer.allocate(4); @@ -2155,7 +2155,7 @@ public Connection(SocketChannel channel, long lastContact, socketSendBufferSize); } } - } + } @Override public String toString() { @@ -2214,17 +2214,17 @@ public Configuration getConf() { private boolean isIdle() { return rpcCount.get() == 0; } - + /* Decrement the outstanding RPC count */ private void decRpcCount() { rpcCount.decrementAndGet(); } - + /* Increment the outstanding RPC count */ private void incRpcCount() { rpcCount.incrementAndGet(); } - + private UserGroupInformation getAuthorizedUgi(String authorizedId) throws InvalidToken, AccessControlException { if (authMethod == AuthMethod.TOKEN) { @@ -2267,7 +2267,7 @@ private void saslReadAndProcess(RpcWritable.Buffer buffer) throws * that are wrapped as a cause of parameter e are unwrapped so that they can * be sent as the true cause to the client side. In case of * {@link InvalidToken} we go one level deeper to get the true cause. - * + * * @param e the exception that may have a cause we want to unwrap. * @return the true cause for some exceptions. */ @@ -2283,7 +2283,7 @@ private Throwable getTrueCause(IOException e) { // callbacks to only returning InvalidToken, but some services // need to throw other exceptions (ex. NN + StandyException), // so for now we'll tunnel the real exceptions via an - // InvalidToken's cause which normally is not set + // InvalidToken's cause which normally is not set if (cause.getCause() != null) { cause = cause.getCause(); } @@ -2293,15 +2293,15 @@ private Throwable getTrueCause(IOException e) { } return e; } - + /** * Process saslMessage and send saslResponse back * @param saslMessage received SASL message * @throws RpcServerException setup failed due to SASL negotiation - * failure, premature or invalid connection context, or other state - * errors. This exception needs to be sent to the client. This - * exception will wrap {@link RetriableException}, - * {@link InvalidToken}, {@link StandbyException} or + * failure, premature or invalid connection context, or other state + * errors. This exception needs to be sent to the client. This + * exception will wrap {@link RetriableException}, + * {@link InvalidToken}, {@link StandbyException} or * {@link SaslException}. * @throws IOException if sending reply fails * @throws InterruptedException @@ -2345,7 +2345,7 @@ private void saslProcess(RpcSaslProto saslMessage) throw tce; } } - + if (saslServer != null && saslServer.isComplete()) { if (LOG.isDebugEnabled()) { LOG.debug("SASL server context established. Negotiated QoP is {}.", @@ -2379,15 +2379,15 @@ private void saslProcess(RpcSaslProto saslMessage) } } } - + /** * Process a saslMessge. * @param saslMessage received SASL message * @return the sasl response to send back to client - * @throws SaslException if authentication or generating response fails, + * @throws SaslException if authentication or generating response fails, * or SASL protocol mixup * @throws IOException if a SaslServer cannot be created - * @throws AccessControlException if the requested authentication type + * @throws AccessControlException if the requested authentication type * is not supported or trying to re-attempt negotiation. * @throws InterruptedException */ @@ -2395,7 +2395,7 @@ private RpcSaslProto processSaslMessage(RpcSaslProto saslMessage) throws SaslException, IOException, AccessControlException, InterruptedException { final RpcSaslProto saslResponse; - final SaslState state = saslMessage.getState(); // required + final SaslState state = saslMessage.getState(); // required switch (state) { case NEGOTIATE: { if (sentNegotiate) { @@ -2518,7 +2518,7 @@ private void checkDataLength(int dataLength) throws IOException { throw new IOException(error); } else if (dataLength > maxDataLength) { String error = "Requested data length " + dataLength + - " is longer than maximum configured RPC length " + + " is longer than maximum configured RPC length " + maxDataLength + ". RPC came from " + getHostAddress(); LOG.warn(error); throw new IOException(error); @@ -2526,17 +2526,17 @@ private void checkDataLength(int dataLength) throws IOException { } /** - * This method reads in a non-blocking fashion from the channel: - * this method is called repeatedly when data is present in the channel; + * This method reads in a non-blocking fashion from the channel: + * this method is called repeatedly when data is present in the channel; * when it has enough data to process one rpc it processes that rpc. - * - * On the first pass, it processes the connectionHeader, - * connectionContext (an outOfBand RPC) and at most one RPC request that + * + * On the first pass, it processes the connectionHeader, + * connectionContext (an outOfBand RPC) and at most one RPC request that * follows that. On future passes it will process at most one RPC request. - * - * Quirky things: dataLengthBuffer (4 bytes) is used to read "hrpc" OR + * + * Quirky things: dataLengthBuffer (4 bytes) is used to read "hrpc" OR * rpc request length. - * + * * @return -1 in case of error, else num bytes read so far * @throws IOException - internal error that should not be returned to * client, typically failure to respond to client @@ -2547,11 +2547,11 @@ public int readAndProcess() throws IOException, InterruptedException { // dataLengthBuffer is used to read "hrpc" or the rpc-packet length int count = -1; if (dataLengthBuffer.remaining() > 0) { - count = channelRead(channel, dataLengthBuffer); - if (count < 0 || dataLengthBuffer.remaining() > 0) + count = channelRead(channel, dataLengthBuffer); + if (count < 0 || dataLengthBuffer.remaining() > 0) return count; } - + if (!connectionHeaderRead) { // Every connection is expected to send the header; // so far we read "hrpc" of the connection header. @@ -2567,7 +2567,7 @@ public int readAndProcess() throws IOException, InterruptedException { // TODO we should add handler for service class later this.setServiceClass(connectionHeaderBuf.get(1)); dataLengthBuffer.flip(); - + // Check if it looks like the user is hitting an IPC port // with an HTTP GET - this is a common error, so we can // send back a simple string indicating as much. @@ -2593,16 +2593,16 @@ public int readAndProcess() throws IOException, InterruptedException { setupBadVersionResponse(version); return -1; } - + // this may switch us into SIMPLE - authProtocol = initializeAuthContext(connectionHeaderBuf.get(2)); - + authProtocol = initializeAuthContext(connectionHeaderBuf.get(2)); + dataLengthBuffer.clear(); // clear to next read rpc packet len connectionHeaderBuf = null; connectionHeaderRead = true; continue; // connection header read, now read 4 bytes rpc packet len } - + if (data == null) { // just read 4 bytes - length of RPC packet dataLengthBuffer.flip(); dataLength = dataLengthBuffer.getInt(); @@ -2612,7 +2612,7 @@ public int readAndProcess() throws IOException, InterruptedException { } // Now read the RPC packet count = channelRead(channel, data); - + if (data.remaining() == 0) { dataLengthBuffer.clear(); // to read length of future rpc packets data.flip(); @@ -2625,7 +2625,7 @@ public int readAndProcess() throws IOException, InterruptedException { if (!isHeaderRead) { continue; } - } + } return count; } return -1; @@ -2637,7 +2637,7 @@ private AuthProtocol initializeAuthContext(int authType) if (authProtocol == null) { IOException ioe = new IpcException("Unknown auth protocol:" + authType); doSaslReply(ioe); - throw ioe; + throw ioe; } boolean isSimpleEnabled = enabledAuthMethods.contains(AuthMethod.SIMPLE); switch (authProtocol) { @@ -2660,10 +2660,10 @@ private AuthProtocol initializeAuthContext(int authType) } /** - * Process the Sasl's Negotiate request, including the optimization of + * Process the Sasl's Negotiate request, including the optimization of * accelerating token negotiation. - * @return the response to Negotiate request - the list of enabled - * authMethods and challenge if the TOKENS are supported. + * @return the response to Negotiate request - the list of enabled + * authMethods and challenge if the TOKENS are supported. * @throws SaslException - if attempt to generate challenge fails. * @throws IOException - if it fails to create the SASL server for Tokens */ @@ -2685,27 +2685,27 @@ private RpcSaslProto buildSaslNegotiateResponse() sentNegotiate = true; return negotiateMessage; } - + private SaslServer createSaslServer(AuthMethod authMethod) throws IOException, InterruptedException { final Map saslProps = saslPropsResolver.getServerProperties(addr, ingressPort); return new SaslRpcServer(authMethod).create(this, saslProps, secretManager); } - + /** * Try to set up the response to indicate that the client version * is incompatible with the server. This can contain special-case * code to speak enough of past IPC protocols to pass back * an exception to the caller. - * @param clientVersion the version the caller is using + * @param clientVersion the version the caller is using * @throws IOException */ private void setupBadVersionResponse(int clientVersion) throws IOException { String errMsg = "Server IPC version " + CURRENT_VERSION + " cannot communicate with client version " + clientVersion; ByteArrayOutputStream buffer = new ByteArrayOutputStream(); - + if (clientVersion >= 9) { // Versions >>9 understand the normal response RpcCall fakeCall = new RpcCall(this, -1); @@ -2731,7 +2731,7 @@ private void setupBadVersionResponse(int clientVersion) throws IOException { sendResponse(fakeCall); } } - + private void setupHttpRequestOnIpcPortResponse() throws IOException { RpcCall fakeCall = new RpcCall(this, 0); fakeCall.setResponse(ByteBuffer.wrap( @@ -2742,7 +2742,7 @@ private void setupHttpRequestOnIpcPortResponse() throws IOException { /** Reads the connection context following the connection header * @throws RpcServerException - if the header cannot be * deserialized, or the user is not authorized - */ + */ private void processConnectionContext(RpcWritable.Buffer buffer) throws RpcServerException { // allow only one connection context during a session @@ -2762,7 +2762,7 @@ private void processConnectionContext(RpcWritable.Buffer buffer) // user is authenticated user.setAuthenticationMethod(authMethod); //Now we check if this is a proxy user case. If the protocol user is - //different from the 'user', it is a proxy user scenario. However, + //different from the 'user', it is a proxy user scenario. However, //this is not allowed if user authenticated with DIGEST. if ((protocolUser != null) && (!protocolUser.getUserName().equals(user.getUserName()))) { @@ -2790,14 +2790,14 @@ private void processConnectionContext(RpcWritable.Buffer buffer) connectionManager.incrUserConnections(user.getShortUserName()); } } - + /** * Process a wrapped RPC Request - unwrap the SASL packet and process - * each embedded RPC request + * each embedded RPC request * @param inBuf - SASL wrapped request of one or more RPCs * @throws IOException - SASL packet cannot be unwrapped * @throws InterruptedException - */ + */ private void unwrapPacketAndProcessRpcs(byte[] inBuf) throws IOException, InterruptedException { LOG.debug("Have read input token of size {} for processing by saslServer.unwrap()", @@ -2833,18 +2833,18 @@ private void unwrapPacketAndProcessRpcs(byte[] inBuf) } } } - + /** - * Process one RPC Request from buffer read from socket stream + * Process one RPC Request from buffer read from socket stream * - decode rpc in a rpc-Call * - handle out-of-band RPC requests such as the initial connectionContext * - A successfully decoded RpcCall will be deposited in RPC-Q and * its response will be sent later when the request is processed. - * + * * Prior to this call the connectionHeader ("hrpc...") has been handled and * if SASL then SASL has been established and the buf we are passed * has been unwrapped from SASL. - * + * * @param bb - contains the RPC request header and the rpc request * @throws IOException - internal error that should not be returned to * client, typically failure to respond to client @@ -2902,15 +2902,15 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } - if (header.getRpcOp() != + if (header.getRpcOp() != RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET) { - String err = "IPC Server does not implement rpc header operation" + + String err = "IPC Server does not implement rpc header operation" + header.getRpcOp(); throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); } // If we know the rpc kind, get its class so that we can deserialize - // (Note it would make more sense to have the handler deserialize but + // (Note it would make more sense to have the handler deserialize but // we continue with this original design. if (!header.hasRpcKind()) { String err = " IPC Server: No rpc kind in rpcRequestHeader"; @@ -2920,7 +2920,7 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) } /** - * Process an RPC Request + * Process an RPC Request * - the connection headers and context must have been already read. * - Based on the rpcKind, decode the rpcRequest. * - A successfully decoded RpcCall will be deposited in RPC-Q and @@ -2937,12 +2937,12 @@ private void checkRpcHeaders(RpcRequestHeaderProto header) private void processRpcRequest(RpcRequestHeaderProto header, RpcWritable.Buffer buffer) throws RpcServerException, InterruptedException { - Class rpcRequestClass = + Class rpcRequestClass = getRpcRequestWrapper(header.getRpcKind()); if (rpcRequestClass == null) { - LOG.warn("Unknown rpc kind " + header.getRpcKind() + + LOG.warn("Unknown rpc kind " + header.getRpcKind() + " from client " + getHostAddress()); - final String err = "Unknown rpc kind in rpc header" + + final String err = "Unknown rpc kind in rpc header" + header.getRpcKind(); throw new FatalRpcServerException( RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, err); @@ -3058,7 +3058,7 @@ private void processRpcRequest(RpcRequestHeaderProto header, * @param buffer - stream to request payload * @throws RpcServerException - setup failed due to SASL * negotiation failure, premature or invalid connection context, - * or other state errors. This exception needs to be sent to the + * or other state errors. This exception needs to be sent to the * client. * @throws IOException - failed to send a response back to the client * @throws InterruptedException @@ -3091,7 +3091,7 @@ private void processRpcOutOfBandRequest(RpcRequestHeaderProto header, RpcErrorCodeProto.FATAL_INVALID_RPC_HEADER, "Unknown out of band call #" + callId); } - } + } /** * Authorize proxy users to access this server @@ -3119,9 +3119,9 @@ private void authorizeConnection() throws RpcServerException { RpcErrorCodeProto.FATAL_UNAUTHORIZED, ae); } } - + /** - * Decode the a protobuf from the given input stream + * Decode the a protobuf from the given input stream * @return Message - decoded protobuf * @throws RpcServerException - deserialization failed */ @@ -3343,33 +3343,33 @@ void logException(Logger logger, Throwable e, Call call) { logger.info(logMsg, e); } } - + protected Server(String bindAddress, int port, - Class paramClass, int handlerCount, + Class paramClass, int handlerCount, Configuration conf) - throws IOException + throws IOException { this(bindAddress, port, paramClass, handlerCount, -1, -1, conf, Integer .toString(port), null, null); } - + protected Server(String bindAddress, int port, Class rpcRequestClass, int handlerCount, int numReaders, int queueSizePerHandler, Configuration conf, String serverName, SecretManager secretManager) throws IOException { - this(bindAddress, port, rpcRequestClass, handlerCount, numReaders, + this(bindAddress, port, rpcRequestClass, handlerCount, numReaders, queueSizePerHandler, conf, serverName, secretManager, null); } - - /** + + /** * Constructs a server listening on the named port and address. Parameters passed must * be of the named class. The handlerCount determines * the number of handler threads that will be used to process calls. * If queueSizePerHandler or numReaders are not -1 they will be used instead of parameters * from configuration. Otherwise the configuration will be picked up. - * - * If rpcRequestClass is null then the rpcRequestClass must have been + * + * If rpcRequestClass is null then the rpcRequestClass must have been * registered via {@link #registerProtocolEngine(RPC.RpcKind, * Class, RPC.RpcInvoker)} * This parameter has been retained for compatibility with existing tests @@ -3398,7 +3398,7 @@ protected Server(String bindAddress, int port, this.conf = conf; this.portRangeConfig = portRangeConfig; this.port = port; - this.rpcRequestClass = rpcRequestClass; + this.rpcRequestClass = rpcRequestClass; this.handlerCount = handlerCount; this.socketSendBufferSize = 0; this.serverName = serverName; @@ -3410,7 +3410,7 @@ protected Server(String bindAddress, int port, } else { this.maxQueueSize = handlerCount * conf.getInt( CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_KEY, - CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_DEFAULT); + CommonConfigurationKeys.IPC_SERVER_HANDLER_QUEUE_SIZE_DEFAULT); } this.maxRespSize = conf.getInt( CommonConfigurationKeys.IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY, @@ -3435,14 +3435,14 @@ protected Server(String bindAddress, int port, maxQueueSize, prefix, conf); this.secretManager = (SecretManager) secretManager; - this.authorize = - conf.getBoolean(CommonConfigurationKeys.HADOOP_SECURITY_AUTHORIZATION, + this.authorize = + conf.getBoolean(CommonConfigurationKeys.HADOOP_SECURITY_AUTHORIZATION, false); // configure supported authentications this.enabledAuthMethods = getAuthMethods(secretManager, conf); this.negotiateResponse = buildNegotiateResponse(enabledAuthMethods); - + // Start the listener here and let it bind to the port listener = new Listener(port); // set the server port to the default listener port. @@ -3468,12 +3468,12 @@ protected Server(String bindAddress, int port, // Create the responder here responder = new Responder(); - + if (secretManager != null || UserGroupInformation.isSecurityEnabled()) { SaslRpcServer.init(conf); saslPropsResolver = SaslPropertiesResolver.getInstance(conf); } - + this.exceptionsHandler.addTerseLoggingExceptions(StandbyException.class); this.exceptionsHandler.addTerseLoggingExceptions( HealthCheckFailedException.class); @@ -3534,7 +3534,7 @@ private RpcSaslProto buildNegotiateResponse(List authMethods) } else { negotiateBuilder.setState(SaslState.NEGOTIATE); for (AuthMethod authMethod : authMethods) { - SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); + SaslRpcServer saslRpcServer = new SaslRpcServer(authMethod); SaslAuth.Builder builder = negotiateBuilder.addAuthsBuilder() .setMethod(authMethod.toString()) .setMechanism(saslRpcServer.mechanism); @@ -3555,31 +3555,31 @@ private RpcSaslProto buildNegotiateResponse(List authMethods) private List getAuthMethods(SecretManager secretManager, Configuration conf) { AuthenticationMethod confAuthenticationMethod = - SecurityUtil.getAuthenticationMethod(conf); + SecurityUtil.getAuthenticationMethod(conf); List authMethods = new ArrayList(); if (confAuthenticationMethod == AuthenticationMethod.TOKEN) { if (secretManager == null) { throw new IllegalArgumentException(AuthenticationMethod.TOKEN + " authentication requires a secret manager"); - } + } } else if (secretManager != null) { LOG.debug("{} authentication enabled for secret manager", AuthenticationMethod.TOKEN); // most preferred, go to the front of the line! authMethods.add(AuthenticationMethod.TOKEN.getAuthMethod()); } - authMethods.add(confAuthenticationMethod.getAuthMethod()); - + authMethods.add(confAuthenticationMethod.getAuthMethod()); + LOG.debug("Server accepts auth methods:{}", authMethods); return authMethods; } - + private void closeConnection(Connection connection) { connectionManager.close(connection); } /** * Setup response for the IPC Call. - * + * * @param call {@link Call} to which we are setting up the response * @param status of the IPC call * @param rv return value for the IPC Call, if the call was successful @@ -3695,11 +3695,11 @@ private static int getDelimitedLength(Message message) { } /** - * Setup response for the IPC Call on Fatal Error from a + * Setup response for the IPC Call on Fatal Error from a * client that is using old version of Hadoop. * The response is serialized using the previous protocol's response * layout. - * + * * @param response buffer to serialize the response into * @param call {@link Call} to which we are setting up the response * @param rv return value for the IPC Call, if the call was successful @@ -3707,9 +3707,9 @@ private static int getDelimitedLength(Message message) { * @param error error message, if the call failed * @throws IOException */ - private void setupResponseOldVersionFatal(ByteArrayOutputStream response, + private void setupResponseOldVersionFatal(ByteArrayOutputStream response, RpcCall call, - Writable rv, String errorClass, String error) + Writable rv, String errorClass, String error) throws IOException { final int OLD_VERSION_FATAL_STATUS = -1; response.reset(); @@ -3742,11 +3742,11 @@ private void wrapWithSasl(RpcCall call) throws IOException { setupResponse(call, saslHeader, RpcWritable.wrap(saslMessage)); } } - + Configuration getConf() { return conf; } - + /** * Sets the socket buffer size used for responding to RPCs. * @param size input size. @@ -3768,7 +3768,7 @@ public synchronized void start() { } handlers = new Handler[handlerCount]; - + for (int i = 0; i < handlerCount; i++) { handlers[i] = new Handler(i); handlers[i].start(); @@ -3851,9 +3851,9 @@ public synchronized Set getAuxiliaryListenerAddresses() { } return allAddrs; } - - /** - * Called for each call. + + /** + * Called for each call. * @deprecated Use {@link #call(RPC.RpcKind, String, * Writable, long)} instead * @param param input param. @@ -3865,7 +3865,7 @@ public synchronized Set getAuxiliaryListenerAddresses() { public Writable call(Writable param, long receiveTime) throws Exception { return call(RPC.RpcKind.RPC_BUILTIN, null, param, receiveTime); } - + /** * Called for each call. * @param rpcKind input rpcKind. @@ -3877,10 +3877,10 @@ public Writable call(Writable param, long receiveTime) throws Exception { */ public abstract Writable call(RPC.RpcKind rpcKind, String protocol, Writable param, long receiveTime) throws Exception; - + /** * Authorize the incoming client connection. - * + * * @param user client user * @param protocolName - the protocol * @param addr InetAddress of incoming connection @@ -3896,13 +3896,13 @@ private void authorize(UserGroupInformation user, String protocolName, try { protocol = getProtocolClass(protocolName, getConf()); } catch (ClassNotFoundException cfne) { - throw new AuthorizationException("Unknown protocol: " + + throw new AuthorizationException("Unknown protocol: " + protocolName); } serviceAuthorizationManager.authorize(user, protocol, getConf(), addr); } } - + /** * Get the port on which the IPC Server is listening for incoming connections. * This could be an ephemeral port too, in which case we return the real @@ -3912,7 +3912,7 @@ private void authorize(UserGroupInformation user, String protocolName, public int getPort() { return port; } - + /** * The number of open RPC conections * @return the number of open rpc connections @@ -3987,25 +3987,25 @@ public int getNumReaders() { } /** - * When the read or write buffer size is larger than this limit, i/o will be + * When the read or write buffer size is larger than this limit, i/o will be * done in chunks of this size. Most RPC requests and responses would be * be smaller. */ private static int NIO_BUFFER_LIMIT = 8*1024; //should not be more than 64KB. - + /** * This is a wrapper around {@link WritableByteChannel#write(ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of + * If the amount of data is large, it writes to channel in smaller chunks. + * This is to avoid jdk from creating many direct buffers as the size of * buffer increases. This also minimizes extra copies in NIO layer - * as a result of multiple write operations required to write a large - * buffer. + * as a result of multiple write operations required to write a large + * buffer. * * @see WritableByteChannel#write(ByteBuffer) */ - private int channelWrite(WritableByteChannel channel, + private int channelWrite(WritableByteChannel channel, ByteBuffer buffer) throws IOException { - + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? channel.write(buffer) : channelIO(null, channel, buffer); if (count > 0) { @@ -4013,19 +4013,19 @@ private int channelWrite(WritableByteChannel channel, } return count; } - - + + /** * This is a wrapper around {@link ReadableByteChannel#read(ByteBuffer)}. - * If the amount of data is large, it writes to channel in smaller chunks. - * This is to avoid jdk from creating many direct buffers as the size of + * If the amount of data is large, it writes to channel in smaller chunks. + * This is to avoid jdk from creating many direct buffers as the size of * ByteBuffer increases. There should not be any performance degredation. - * + * * @see ReadableByteChannel#read(ByteBuffer) */ - private int channelRead(ReadableByteChannel channel, + private int channelRead(ReadableByteChannel channel, ByteBuffer buffer) throws IOException { - + int count = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? channel.read(buffer) : channelIO(channel, null, buffer); if (count > 0) { @@ -4033,43 +4033,43 @@ private int channelRead(ReadableByteChannel channel, } return count; } - + /** * Helper for {@link #channelRead(ReadableByteChannel, ByteBuffer)} * and {@link #channelWrite(WritableByteChannel, ByteBuffer)}. Only * one of readCh or writeCh should be non-null. - * + * * @see #channelRead(ReadableByteChannel, ByteBuffer) * @see #channelWrite(WritableByteChannel, ByteBuffer) */ - private static int channelIO(ReadableByteChannel readCh, + private static int channelIO(ReadableByteChannel readCh, WritableByteChannel writeCh, ByteBuffer buf) throws IOException { - + int originalLimit = buf.limit(); int initialRemaining = buf.remaining(); int ret = 0; - + while (buf.remaining() > 0) { try { int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); buf.limit(buf.position() + ioSize); - - ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); - + + ret = (readCh == null) ? writeCh.write(buf) : readCh.read(buf); + if (ret < ioSize) { break; } } finally { - buf.limit(originalLimit); + buf.limit(originalLimit); } } - int nBytes = initialRemaining - buf.remaining(); + int nBytes = initialRemaining - buf.remaining(); return (nBytes > 0) ? nBytes : ret; } - + private class ConnectionManager { final private AtomicInteger count = new AtomicInteger(); final private AtomicLong droppedConnections = new AtomicLong(); @@ -4084,7 +4084,7 @@ private class ConnectionManager { final private int maxIdleTime; final private int maxIdleToClose; final private int maxConnections; - + ConnectionManager() { this.idleScanTimer = new Timer( "IPC Server idle connection scanner for port " + getPort(), true); @@ -4118,7 +4118,7 @@ private boolean add(Connection connection) { } return added; } - + private boolean remove(Connection connection) { boolean removed = connections.remove(connection); if (removed) { @@ -4189,7 +4189,7 @@ Connection register(SocketChannel channel, int ingressPort, connection, size(), callQueue.size()); return connection; } - + boolean close(Connection connection) { boolean exists = remove(connection); if (exists) { @@ -4207,7 +4207,7 @@ boolean close(Connection connection) { } return exists; } - + // synch'ed to avoid explicit invocation upon OOM from colliding with // timer task firing synchronized void closeIdle(boolean scanAll) { @@ -4230,7 +4230,7 @@ synchronized void closeIdle(boolean scanAll) { } } } - + void closeAll() { // use a copy of the connections to be absolutely sure the concurrent // iterator doesn't miss a connection @@ -4238,15 +4238,15 @@ void closeAll() { close(connection); } } - + void startIdleScan() { scheduleIdleScanTask(); } - + void stopIdleScan() { idleScanTimer.cancel(); } - + private void scheduleIdleScanTask() { if (!running) { return; diff --git a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto index 71a75f11b9097..19bdc96726b0e 100644 --- a/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto +++ b/hadoop-common-project/hadoop-common/src/main/proto/RpcHeader.proto @@ -29,7 +29,7 @@ package hadoop.common; /** * This is the rpc request header. It is sent with every rpc call. - * + * * The format of RPC call is as follows: * +--------------------------------------------------------------+ * | Rpc length in bytes (4 bytes int) sum of next two parts | @@ -47,12 +47,12 @@ package hadoop.common; */ enum RpcKindProto { RPC_BUILTIN = 0; // Used for built in calls by tests - RPC_WRITABLE = 1; // Use WritableRpcEngine + RPC_WRITABLE = 1; // Use WritableRpcEngine RPC_PROTOCOL_BUFFER = 2; // Use ProtobufRpcEngine } - + /** * Used to pass through the information necessary to continue * a trace after an RPC is made. All we need is the traceid @@ -117,12 +117,12 @@ message RpcRequestHeaderProto { // the header for the RpcRequest * | The rpc response header contains the necessary info | * +------------------------------------------------------------------+ * - * Note that rpc response header is also used when connection setup fails. + * Note that rpc response header is also used when connection setup fails. * Ie the response looks like a rpc response with a fake callId. */ message RpcResponseHeaderProto { /** - * + * * RpcStastus - success or failure * The reponseHeader's errDetail, exceptionClassName and errMsg contains * further details on the error @@ -178,7 +178,7 @@ message RpcSaslProto { RESPONSE = 4; WRAP = 5; } - + message SaslAuth { required string method = 1; required string mechanism = 2; @@ -187,7 +187,7 @@ message RpcSaslProto { optional bytes challenge = 5; } - optional uint32 version = 1; + optional uint32 version = 1; required SaslState state = 2; optional bytes token = 3; repeated SaslAuth auths = 4; From 2834b19379e2a07622cb7de903bcd3f5e0c95c8d Mon Sep 17 00:00:00 2001 From: Tom McCormick Date: Tue, 22 Jul 2025 14:45:43 -0400 Subject: [PATCH 6/6] Always clear auth context --- .../src/main/java/org/apache/hadoop/ipc/Server.java | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java index 37a1e8e499240..ca7460a653c9a 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Server.java @@ -2992,11 +2992,9 @@ private void processRpcRequest(RpcRequestHeaderProto header, // Set AuthorizationContext for this thread if present byte[] authHeader = null; - boolean authzSet = false; try { if (header.hasAuthorizationHeader()) { authHeader = header.getAuthorizationHeader().toByteArray(); - authzSet = true; } RpcCall call = new RpcCall(this, header.getCallId(), @@ -3045,9 +3043,7 @@ private void processRpcRequest(RpcRequestHeaderProto header, } incRpcCount(); // Increment the rpc count } finally { - if (authzSet) { - AuthorizationContext.clear(); - } + AuthorizationContext.clear(); } }