summaryrefslogblamecommitdiffstats
path: root/src/main/java/org/openslx/thrifthelper/TBinaryProtocolSafe.java
blob: efa83c31fbdb1172e28ad455e38a68f62a6cf189 (plain) (tree)
1
2
3
4
5
6
7
8
9
10

                                 


                                  
                                
                           


                                         
 

                                           





                                                     
                                           
                                              
                                                       
                                                             







                                                        


                                                                                         


                  


                                                               




                                                                                  































                                                                                             
                                                                        

         
          

                           
 
                 

                                                            








                                                                                                                               


                                                                                                               
                                 


                                                                                                                       

                                                                                                                                                                 

                                                                                                 
                                                                                                                                                      

                                                                                                             
                                                                  



                                                                                                 
                                    
                                                                                                                    


                                                          

                                                                                                             









                                                                                                                                                   




















                                                                                               
                 





                                                                                                          


                                                                                                                              




                                              
                 
















                                                                                                                
package org.openslx.thrifthelper;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import javax.net.ssl.SSLException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.layered.TLayeredTransport;

/**
 * Binary protocol implementation for thrift that limits how much data it is
 * willing to read.
 * <p>
 * Message headers, strings and binary fields whose announced length exceeds
 * 12 MiB (or is negative) are rejected with a {@link TProtocolException}
 * instead of being read. In addition, {@link #readMessageBegin()} converts a
 * number of transport errors into a plain
 * {@link TTransportException#END_OF_FILE} exception to keep the server log
 * quiet.
 */
public class TBinaryProtocolSafe extends TBinaryProtocol
{

	private static final Logger LOGGER = LogManager.getLogger( TBinaryProtocolSafe.class );

	/**
	 * Maximum accepted length, in bytes, of a message header, string or binary
	 * field. Also passed to the super class as string and container length limit.
	 */
	private static final int MAX_LEN = 12 * 1024 * 1024; // 12 MiB

	/**
	 * Factory creating {@link TBinaryProtocolSafe} instances for a transport.
	 */
	public static class Factory implements TProtocolFactory
	{

		/**
		 * Version for serialization.
		 */
		private static final long serialVersionUID = 6896537370338823740L;

		protected boolean strictRead_ = false;
		protected boolean strictWrite_ = true;

		/**
		 * Creates a factory with non-strict reading and strict writing.
		 */
		public Factory()
		{
			this( false, true );
		}

		/**
		 * @param strictRead passed on to every protocol instance created by this factory
		 * @param strictWrite passed on to every protocol instance created by this factory
		 */
		public Factory(boolean strictRead, boolean strictWrite)
		{
			strictRead_ = strictRead;
			strictWrite_ = strictWrite;
		}

		/**
		 * @param trans transport the new protocol instance should operate on
		 * @return a new {@link TBinaryProtocolSafe} using this factory's strictness settings
		 */
		@Override
		public TProtocol getProtocol( TTransport trans )
		{
			return new TBinaryProtocolSafe( trans, strictRead_, strictWrite_ );
		}
	}

	/**
	 * Creates a protocol with non-strict reading and strict writing.
	 *
	 * @param trans transport to read from and write to
	 */
	public TBinaryProtocolSafe(TTransport trans)
	{
		this( trans, false, true );
	}

	/**
	 * @param trans transport to read from and write to
	 * @param strictRead passed to the super class; if set, messages without a
	 *           version header are rejected in {@link #readMessageBegin()}
	 * @param strictWrite passed to the super class
	 */
	public TBinaryProtocolSafe(TTransport trans, boolean strictRead, boolean strictWrite)
	{
		super( trans, MAX_LEN, MAX_LEN, strictRead, strictWrite );
	}

	/*
	 * Reading methods.
	 */

	/**
	 * Reads the header of the next message.
	 *
	 * @return the decoded message header
	 * @throws TProtocolException if the announced size exceeds the limit, or if
	 *            strict reading is enabled and the version header is missing
	 * @throws TTransportException on transport errors; several error conditions
	 *            are deliberately reported as {@link TTransportException#END_OF_FILE}
	 *            (see {@link #filterReadFailure(TTransportException)} and the bad
	 *            version case)
	 */
	@Override
	public TMessage readMessageBegin() throws TException
	{
		int size;
		try {
			size = readI32();
		} catch ( TTransportException e ) {
			throw filterReadFailure( e );
		}
		if ( size > MAX_LEN )
			throw new TProtocolException( TProtocolException.SIZE_LIMIT, getIp() + "Payload too big." );
		if ( size < 0 ) {
			// Negative value: high bits carry the protocol version, low byte the message type
			int version = size & VERSION_MASK;
			if ( version != VERSION_1 ) {
				LOGGER.warn( getIp() + "Bad version (" + version + ") in readMessageBegin" );
				throw new TTransportException( TTransportException.END_OF_FILE );
			}
			return new TMessage( readString(), (byte) ( size & 0x000000ff ), readI32() );
		} else {
			if ( strictRead_ ) {
				throw new TProtocolException( TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?" );
			}
			// No version header: the value read is the length of the message name
			return new TMessage( readStringBody( size ), readByte(), readI32() );
		}
	}

	/**
	 * Decides which exception to throw for a transport error that occurred while
	 * reading the first four bytes of a message.
	 * <p>
	 * Messages of the exception and its cause may be <code>null</code>; a
	 * <code>null</code> message never matches any of the patterns checked here.
	 *
	 * @param e exception thrown by the initial read
	 * @return a fresh {@link TTransportException#END_OF_FILE} exception if the
	 *         error should be hidden, otherwise <code>e</code> itself
	 */
	private TTransportException filterReadFailure( TTransportException e )
	{
		Throwable cause = e.getCause();
		if ( cause instanceof SSLException ) {
			// Do this to suppress certain SSL handshake errors that result from port scanning and service probing
			String m = cause.getMessage();
			// We still want SSL errors that help diagnosing more specific SSL errors that relate to actual
			// SSL handshake attempts, like incompatible TLS versions or ciphers.
			if ( !containsAny( m, "Remote host terminated the handshake", "Unsupported or unrecognized SSL message" ) ) {
				LOGGER.warn( getIp() + ( m != null ? m : cause.toString() ) );
			}
			// Fake an END_OF_FILE exception, as the logException() method in the server class is
			// expected to ignore those. Let's hope it will stay ignored in the future.
			return new TTransportException( TTransportException.END_OF_FILE );
		}
		if ( cause instanceof SocketException
				&& containsAny( cause.getMessage(), " timed out", "Connection reset" ) ) {
			// Fake EOF as well
			return new TTransportException( TTransportException.END_OF_FILE );
		}
		if ( containsAny( e.getMessage(), "larger than max length", "Read a negative frame size" ) ) {
			// Also fake, since this one prints a whole stack trace compared to the other
			// message by AbstractNonblockingServer
			LOGGER.debug( e.getMessage(), e );
			return new TTransportException( TTransportException.END_OF_FILE );
		}
		return e;
	}

	/**
	 * Null-safe substring test.
	 *
	 * @param haystack string to search in, may be <code>null</code>
	 * @param needles substrings to look for
	 * @return <code>true</code> if <code>haystack</code> is not <code>null</code>
	 *         and contains at least one of <code>needles</code>
	 */
	private static boolean containsAny( String haystack, String... needles )
	{
		if ( haystack == null )
			return false;
		for ( String needle : needles ) {
			if ( haystack.contains( needle ) )
				return true;
		}
		return false;
	}

	/**
	 * Determines the remote address of the underlying socket for use as a log
	 * message prefix. Layered transports are unwrapped until the innermost
	 * transport is reached.
	 *
	 * @return the remote host address followed by <code>": "</code>, or an empty
	 *         string if no address could be determined
	 */
	private String getIp()
	{
		TTransport t = trans_;
		while ( t instanceof TLayeredTransport ) {
			t = ( (TLayeredTransport)t ).getInnerTransport();
		}
		InetAddress ia = null;
		if ( t instanceof TSocket ) {
			TSocket socket = (TSocket)t;
			SocketAddress sa = socket.getSocket().getRemoteSocketAddress();
			if ( sa instanceof InetSocketAddress )
				ia = ( (InetSocketAddress)sa ).getAddress();
			if ( ia == null )
				ia = socket.getSocket().getInetAddress();
		} else if ( t != null ) {
			LOGGER.debug( "getIp(" + t.getClass().getSimpleName() + ")" );
		}
		if ( ia == null )
			return "";
		return ia.getHostAddress() + ": ";
	}

	/**
	 * Rejects length prefixes that are negative or exceed {@link #MAX_LEN}.
	 * The length comes straight from the wire, so it must be validated before
	 * it is used as an array size or buffer offset.
	 *
	 * @param size length prefix read from the transport
	 * @throws TProtocolException with type {@link TProtocolException#NEGATIVE_SIZE}
	 *            or {@link TProtocolException#SIZE_LIMIT}
	 */
	private void checkPayloadSize( int size ) throws TProtocolException
	{
		if ( size < 0 )
			throw new TProtocolException( TProtocolException.NEGATIVE_SIZE, "Negative length: " + size );
		if ( size > MAX_LEN )
			throw new TProtocolException( TProtocolException.SIZE_LIMIT, "Payload too big." );
	}

	/**
	 * Reads a length-prefixed UTF-8 string.
	 *
	 * @return the decoded string
	 * @throws TProtocolException if the announced length is negative or exceeds
	 *            the limit
	 */
	@Override
	public String readString() throws TException
	{
		int size = readI32();
		checkPayloadSize( size );
		if ( trans_.getBytesRemainingInBuffer() >= size ) {
			// Enough data already buffered by the transport: decode in place
			String s = new String( trans_.getBuffer(), trans_.getBufferPosition(), size, StandardCharsets.UTF_8 );
			trans_.consumeBuffer( size );
			return s;
		}

		return readStringBody( size );
	}

	/**
	 * Reads a length-prefixed binary field.
	 *
	 * @return buffer holding the field's bytes; if the transport had enough data
	 *         buffered, this wraps the transport's own buffer instead of a copy
	 * @throws TProtocolException if the announced length is negative or exceeds
	 *            the limit
	 */
	@Override
	public ByteBuffer readBinary() throws TException
	{
		int size = readI32();
		checkPayloadSize( size );
		if ( trans_.getBytesRemainingInBuffer() >= size ) {
			ByteBuffer bb = ByteBuffer.wrap( trans_.getBuffer(), trans_.getBufferPosition(), size );
			trans_.consumeBuffer( size );
			return bb;
		}

		byte[] buf = new byte[ size ];
		trans_.readAll( buf, 0, size );
		return ByteBuffer.wrap( buf );
	}

}