summary | refs | log | tree | commit | diff | stats
path: root/src/main/java/org/openslx/thrifthelper/TBinaryProtocolSafe.java
blob: 86a2306944a099b4e98efbe17625073fa1b07cb5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package org.openslx.thrifthelper;

import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;

import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TTransport;

/**
 * Binary protocol implementation for thrift.
 * Will not read messages bigger than 12MiB.
 * 
 */
public class TBinaryProtocolSafe extends TBinaryProtocol
{
	/**
	 * Factory
	 */
	@SuppressWarnings( "serial" )
	public static class Factory implements TProtocolFactory
	{

		protected boolean strictRead_ = false;
		protected boolean strictWrite_ = true;

		public Factory()
		{
			this( false, true );
		}

		public Factory(boolean strictRead, boolean strictWrite)
		{
			strictRead_ = strictRead;
			strictWrite_ = strictWrite;
		}

		public TProtocol getProtocol( TTransport trans )
		{
			return new TBinaryProtocolSafe( trans, strictRead_, strictWrite_ );
		}
	}

	private static final int maxLen = 12 * 1024 * 1024; // 12 MiB

	/**
	 * Constructor
	 */
	public TBinaryProtocolSafe(TTransport trans)
	{
		this( trans, false, true );
	}

	public TBinaryProtocolSafe(TTransport trans, boolean strictRead, boolean strictWrite)
	{
		super( trans );
		strictRead_ = strictRead;
		strictWrite_ = strictWrite;
	}

	/**
	 * Reading methods.
	 */

	public TMessage readMessageBegin() throws TException
	{
		int size = readI32();
		if ( size > maxLen )
			throw new TProtocolException( TProtocolException.SIZE_LIMIT, "Payload too big." );
		if ( size < 0 ) {
			int version = size & VERSION_MASK;
			if ( version != VERSION_1 ) {
				throw new TProtocolException( TProtocolException.BAD_VERSION, "Bad version in readMessageBegin" );
			}
			return new TMessage( readString(), (byte) ( size & 0x000000ff ), readI32() );
		} else {
			if ( strictRead_ ) {
				throw new TProtocolException( TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?" );
			}
			return new TMessage( readStringBody( size ), readByte(), readI32() );
		}
	}

	public String readString() throws TException
	{
		int size = readI32();
		if ( size > maxLen )
			throw new TProtocolException( TProtocolException.SIZE_LIMIT, "Payload too big." );
		if ( trans_.getBytesRemainingInBuffer() >= size ) {
			try {
				String s = new String( trans_.getBuffer(), trans_.getBufferPosition(), size, "UTF-8" );
				trans_.consumeBuffer( size );
				return s;
			} catch ( UnsupportedEncodingException e ) {
				throw new TException( "JVM DOES NOT SUPPORT UTF-8" );
			}
		}

		return readStringBody( size );
	}

	public ByteBuffer readBinary() throws TException
	{
		int size = readI32();
		if ( size > maxLen )
			throw new TProtocolException( TProtocolException.SIZE_LIMIT, "Payload too big." );
		if ( trans_.getBytesRemainingInBuffer() >= size ) {
			ByteBuffer bb = ByteBuffer.wrap( trans_.getBuffer(), trans_.getBufferPosition(), size );
			trans_.consumeBuffer( size );
			return bb;
		}

		byte[] buf = new byte[ size ];
		trans_.readAll( buf, 0, size );
		return ByteBuffer.wrap( buf );
	}

}