summaryrefslogtreecommitdiffstats
path: root/src/main/java/org/openslx/thrifthelper/ThriftHandler.java
blob: 0d30ccb79786a665f998640df8cc202f64ce9f94 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package org.openslx.thrifthelper;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

import org.apache.log4j.Logger;
import org.apache.thrift.TException;
import org.apache.thrift.transport.TTransportException;

/**
 * {@link InvocationHandler} that forwards calls made on a proxy to a client instance kept per
 * thread (via {@link ThreadLocal}), obtained from an {@link EventCallback}.
 * <p>
 * Methods of the proxied type that declare a {@link TException} ("thrift methods") are retried
 * up to {@value #MAX_ATTEMPTS} times; after each {@link TTransportException} a fresh client is
 * requested from the callback. Any other failure is rethrown immediately.
 *
 * @param <T> the client type being proxied
 */
class ThriftHandler<T> implements InvocationHandler
{

	private final static Logger LOGGER = Logger.getLogger( ThriftHandler.class );

	/** Maximum number of invocations of one thrift method before giving up. */
	private static final int MAX_ATTEMPTS = 3;

	/**
	 * Hooks used by the handler to create clients and to report unrecoverable failures.
	 *
	 * @param <T> the client type
	 */
	public interface EventCallback<T>
	{
		/**
		 * Creates a new client. Called when the current thread has no client yet, and after
		 * every transport error. Returning {@code null} makes the proxied call fail with an
		 * {@link IllegalStateException}.
		 */
		public T getNewClient();

		/**
		 * Called once a thrift method failed with a transport error on every attempt, right
		 * before that error is rethrown to the caller.
		 *
		 * @param t the last transport error
		 * @param message human-readable description of the failure
		 */
		public void error( Throwable t, String message );
	}

	/** The client used by each thread; replaced after transport errors. */
	private final ThreadLocal<T> clients = new ThreadLocal<T>();
	private final EventCallback<T> callback;
	/** Names of the proxied type's methods that get the retry/reconnect treatment. */
	private final Set<String> thriftMethods;

	/**
	 * @param clazz the proxied type; its public methods declaring a {@link TException} are
	 *           treated as thrift methods
	 * @param cb callback used to create clients and report failures
	 */
	public ThriftHandler( final Class<T> clazz, EventCallback<T> cb )
	{
		callback = cb;
		thriftMethods = Collections.unmodifiableSet( collectThriftMethods( clazz ) );
	}

	/** Collects the names of all public methods of {@code clazz} that declare a TException. */
	private static Set<String> collectThriftMethods( Class<?> clazz )
	{
		Set<String> names = new HashSet<String>();
		for ( Method method : clazz.getMethods() ) {
			String name = method.getName();
			// Excluded by name - presumably thrift's generated send_/recv_ halves of a call
			if ( name.startsWith( "send_" ) || name.startsWith( "recv_" ) )
				continue;
			if ( declaresThriftException( method ) ) {
				names.add( name );
			}
		}
		return names;
	}

	/** Returns true if {@code method} declares TException or a subclass of it. */
	private static boolean declaresThriftException( Method method )
	{
		for ( Class<?> type : method.getExceptionTypes() ) {
			if ( TException.class.isAssignableFrom( type ) )
				return true;
		}
		return false;
	}

	/**
	 * Dispatches a call made on the proxy.
	 * <ul>
	 * <li>{@code equals}/{@code hashCode}/{@code toString} are answered by proxy identity,
	 * without creating a client.</li>
	 * <li>Other non-thrift methods are forwarded once to the current thread's client.</li>
	 * <li>Thrift methods are retried after transport errors, reconnecting each time.</li>
	 * </ul>
	 *
	 * @return the result of the forwarded call
	 * @throws Throwable whatever the underlying client method threw (unwrapped)
	 */
	@Override
	public Object invoke( Object tproxy, Method method, Object[] args ) throws Throwable
	{
		if ( method.getDeclaringClass() == Object.class ) {
			return invokeObjectMethod( tproxy, method, args );
		}

		if ( !thriftMethods.contains( method.getName() ) ) {
			try {
				return method.invoke( getClient( false ), args );
			} catch ( InvocationTargetException e ) {
				throw unwrap( e );
			}
		}
		LOGGER.debug( "Proxying '" + method.getName() + "'" );

		T client = getClient( false );
		Throwable cause = null;
		for ( int attempt = 0; attempt < MAX_ATTEMPTS; attempt++ ) {
			try {
				return method.invoke( client, args );
			} catch ( InvocationTargetException e ) {
				cause = unwrap( e );
				if ( !( cause instanceof TTransportException ) ) {
					// Not a connection problem: a reconnect/retry would not help, so fail fast
					throw cause;
				}
				LOGGER.debug( "Transport error - re-initialising ..." );
				// Always replace the client, so the next call on this thread starts fresh too
				client = getClient( true );
			}
		}

		// Every attempt failed with a transport error
		callback.error( cause, "Could not reconnect to thrift server - network or server down?" );
		throw cause;
	}

	/**
	 * Handles the {@link Object} methods that {@link java.lang.reflect.Proxy} routes to the
	 * handler, using the proxy's identity.
	 */
	private static Object invokeObjectMethod( Object proxy, Method method, Object[] args )
	{
		String name = method.getName();
		if ( "equals".equals( name ) )
			return proxy == args[0];
		if ( "hashCode".equals( name ) )
			return System.identityHashCode( proxy );
		if ( "toString".equals( name ) )
			return proxy.getClass().getName() + "@" + Integer.toHexString( System.identityHashCode( proxy ) );
		throw new UnsupportedOperationException( "Unexpected Object method: " + name );
	}

	/** Returns the exception thrown by the invoked method, or {@code e} itself if it has no cause. */
	private static Throwable unwrap( InvocationTargetException e )
	{
		Throwable cause = e.getCause();
		return cause == null ? e : cause;
	}

	/**
	 * Returns the current thread's client. A new one is requested from the callback if there is
	 * none yet or if {@code forceNew} is set; the new client is then stored for this thread.
	 *
	 * @throws IllegalStateException if the callback returns {@code null}
	 */
	private T getClient( boolean forceNew )
	{
		T client = clients.get();
		if ( client != null && !forceNew ) {
			return client;
		}
		client = callback.getNewClient();
		if ( client == null ) {
			// TODO dedicated exception type
			throw new IllegalStateException( "EventCallback.getNewClient() returned null" );
		}
		clients.set( client );
		return client;
	}
}