package org.openslx.taskmanager.tasks;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.openslx.taskmanager.api.AbstractTask;

import com.google.gson.annotations.Expose;
import com.jcraft.jsch.ChannelExec;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
/**
 * Task that reboots or powers off a set of clients by running the idleaction scheduling script
 * on each of them over SSH (user root, authenticated with the configured private key).
 * <p>
 * Clients are contacted in parallel by up to {@value #MAX_THREADS} worker threads. For an
 * immediate reboot ({@code minutes == 0}) the task afterwards polls the rebooted clients until
 * they answer on a TCP port again, or until roughly 2.5 minutes have passed.
 */
public class RemoteReboot extends AbstractTask
{
	/** Clients to act on, each identified by machine UUID and IP address. */
	@Expose
	private Client[] clients;
	/** If true, power the clients off; otherwise reboot them. */
	@Expose
	private boolean shutdown;
	/** Delay in minutes before the action, passed to the remote command (0 = immediately). */
	@Expose
	private int minutes;
	/** Copied verbatim into the status output. */
	@Expose
	private String locationId;
	/** Copied verbatim into the status output. */
	@Expose
	private String locationName;
	/** Private key material handed to JSch for authentication. */
	@Expose
	private String sshkey;
	/** SSH port on the clients. */
	@Expose
	private int port;

	private JSch sshClient;

	private Output status = new Output();

	private static final String REBOOT_CMD = "/opt/openslx/scripts/idleaction-scheduled_action --detach reboot";

	private static final String SHUTDOWN_CMD = "/opt/openslx/scripts/idleaction-scheduled_action --detach poweroff";

	/** Maximum number of clients contacted in parallel. */
	private static final int MAX_THREADS = 4;
	/** Timeout for establishing the SSH session (ms). */
	private static final int SSH_CONNECT_TIMEOUT_MS = 5000;
	/** Timeout for opening the exec channel (ms). */
	private static final int CHANNEL_CONNECT_TIMEOUT_MS = 2000;
	/** How long to wait for the remote command to finish (ms). */
	private static final int COMMAND_TIMEOUT_MS = 2000;

	/**
	 * Validates the task parameters, fills the static parts of the status object and loads
	 * the SSH key.
	 *
	 * @return false if a parameter is invalid or the key could not be loaded (the reason is
	 *         recorded in the status error), true otherwise
	 */
	@Override
	protected boolean initTask()
	{
		this.setStatusObject( this.status );
		if ( minutes < 0 ) {
			status.addError( "Delay cannot be negative" );
		}
		if ( sshkey == null || sshkey.length() == 0 ) {
			status.addError( "No SSH key given" );
		}
		if ( port < 1 || port > 65535 ) {
			status.addError( "Invalid port number" );
		}
		if ( status.error != null )
			return false;
		status.clients = clients;
		// long arithmetic: minutes * 60 * 1000 in int would overflow for very large delays
		Date shutdownTime = new Date( System.currentTimeMillis() + TimeUnit.MINUTES.toMillis( minutes ) );
		SimpleDateFormat sdf = new SimpleDateFormat( "HH:mm" );
		status.time = sdf.format( shutdownTime );
		status.locationId = locationId;
		status.locationName = locationName;
		// NOTE(review): this is a global JSch setting and disables host key verification,
		// i.e. client identities are not authenticated (MITM possible).
		JSch.setConfig( "StrictHostKeyChecking", "no" );
		sshClient = new JSch();
		try {
			sshClient.addIdentity( "", sshkey.getBytes( StandardCharsets.UTF_8 ), null, null );
		} catch ( JSchException e ) {
			status.addError( e.toString() );
			return false;
		}
		return true;
	}

	/**
	 * Triggers the reboot/shutdown on all clients in parallel, then (for immediate reboots)
	 * waits for the rebooted clients to come back online. Clients that got stuck are finally
	 * marked as {@link ClientStatus#ERROR}.
	 *
	 * @return false if the task thread was interrupted, true otherwise
	 */
	@Override
	protected boolean execute()
	{
		if ( clients == null || clients.length == 0 )
			return true;
		// Filled concurrently by the worker threads; every access synchronizes on the list itself
		final List<Client> rebootingClients = new ArrayList<>();
		final int poolSize = Math.min( clients.length, MAX_THREADS );
		ExecutorService tp = Executors.newFixedThreadPool( poolSize );
		// try to connect to every client and start the reboot/shutdown process
		for ( final Client client : clients ) {
			if ( client == null || client.clientip == null || client.machineuuid == null ) {
				status.addError( "null Client or missing ip/uuid in list, ignoring." );
				continue;
			}
			tp.submit( new Runnable() {
				@Override
				public void run()
				{
					if ( triggerClient( client ) ) {
						synchronized ( rebootingClients ) {
							rebootingClients.add( client );
						}
					}
				}
			} );
		}
		tp.shutdown();
		// Worst case per client: SSH connect + channel connect + command wait (+ slack). Clients
		// are handled in batches of poolSize; never wait less than the historical n * 5 s.
		long perClientMs = SSH_CONNECT_TIMEOUT_MS + CHANNEL_CONNECT_TIMEOUT_MS + COMMAND_TIMEOUT_MS + 1000L;
		long batches = ( clients.length + poolSize - 1 ) / poolSize;
		long timeoutMs = Math.max( clients.length * 5000L, batches * perClientMs );
		try {
			tp.awaitTermination( timeoutMs, TimeUnit.MILLISECONDS );
		} catch ( InterruptedException e ) {
			tp.shutdownNow();
			Thread.currentThread().interrupt();
			return false;
		}
		if ( minutes == 0 ) {
			// Snapshot: workers that exceeded the timeout may still be running
			final List<Client> pending;
			synchronized ( rebootingClients ) {
				pending = new ArrayList<>( rebootingClients );
			}
			if ( !pending.isEmpty() && !waitForClientsOnline( pending ) )
				return false;
		}
		markStuckClientsAsError();
		return true;
	}

	/**
	 * Connects to a single client via SSH, runs the reboot or shutdown command on it and
	 * records the outcome in {@link Output#clientStatus} (and the status error on failure).
	 * The SSH channel and session are always closed before returning.
	 *
	 * @param client client to act on; ip and uuid must be non-null
	 * @return true if a reboot (not a shutdown) was successfully triggered on the client
	 */
	private boolean triggerClient( Client client )
	{
		int ret = -1;
		boolean rebooting = false;
		Session session = null;
		ChannelExec channel = null;
		try {
			status.clientStatus.put( client.machineuuid, ClientStatus.CONNECTING );
			session = sshClient.getSession( "root", client.clientip, port );
			session.connect( SSH_CONNECT_TIMEOUT_MS );
			channel = (ChannelExec)session.openChannel( "exec" );
			// uuid is single-quoted for the remote shell; embedded quotes become '\''
			String args = " " + minutes + " " + String.format( "'%s'", client.machineuuid.replace( "'", "'\\''" ) );
			channel.setCommand( ( shutdown ? SHUTDOWN_CMD : REBOOT_CMD ) + args );
			channel.connect( CHANNEL_CONNECT_TIMEOUT_MS );
			waitForCommand( channel, COMMAND_TIMEOUT_MS );
			ret = channel.getExitStatus();
			if ( ret == -1 && !channel.isClosed() ) {
				status.addError( client.clientip + ": Timeout waiting for command to finish" );
			}
			if ( ret == 0 ) {
				if ( shutdown ) {
					status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.SHUTDOWN : ClientStatus.SHUTDOWN_AT );
				} else {
					status.clientStatus.put( client.machineuuid, minutes == 0 ? ClientStatus.REBOOTING : ClientStatus.REBOOT_AT );
					rebooting = true;
				}
			}
		} catch ( JSchException e ) {
			if ( e.toString().contains( "Auth fail" ) ) {
				status.clientStatus.put( client.machineuuid, ClientStatus.AUTH_FAIL );
				ret = 0; // AUTH_FAIL is its own final state, don't turn it into ERROR below
			} else {
				status.addError( client.clientip + ": " + e.toString() );
				ret = -1;
			}
		} catch ( RuntimeException e ) {
			// Would otherwise vanish inside the executor's Future without any trace
			status.addError( client.clientip + ": " + e.toString() );
			ret = -1;
		} finally {
			if ( channel != null ) {
				channel.disconnect();
			}
			if ( session != null ) {
				session.disconnect();
			}
		}
		if ( ret != 0 ) {
			if ( ret != -1 ) {
				status.addError( client.clientip + ": Exit Code " + ret );
			}
			status.clientStatus.put( client.machineuuid, ClientStatus.ERROR );
		}
		return rebooting;
	}

	/**
	 * Polls the given clients until each is reachable again or the deadline passes: waits 30 s
	 * first, then checks about every 10 s for up to 2 more minutes. Clients that answer are
	 * removed from the list and marked {@link ClientStatus#ONLINE}.
	 *
	 * @param pending clients that were told to reboot immediately; modified in place
	 * @return false if the thread was interrupted while waiting, true otherwise
	 */
	private boolean waitForClientsOnline( List<Client> pending )
	{
		// Determine online state if either ssh or winrpc/smb is open
		final int[] ports;
		if ( this.port == 22 ) {
			ports = new int[] { 22, 445 };
		} else {
			ports = new int[] { this.port, 22, 445 };
		}
		// Assume the boot loop takes at least 30 secs, don't even try before that
		try {
			Thread.sleep( 30000 );
		} catch ( InterruptedException e ) {
			Thread.currentThread().interrupt();
			return false;
		}
		long lastcheck = 0;
		long deadline = System.currentTimeMillis() + 120 * 1000;
		while ( !pending.isEmpty() && System.currentTimeMillis() < deadline ) {
			// Keep roughly 10 s between the starts of two probe rounds
			long delay = 10000 - ( System.currentTimeMillis() - lastcheck );
			if ( delay > 0 ) {
				try {
					Thread.sleep( delay );
				} catch ( InterruptedException e ) {
					Thread.currentThread().interrupt();
					return false;
				}
			}
			lastcheck = System.currentTimeMillis();
			Iterator<Client> it = pending.iterator();
			while ( it.hasNext() ) {
				Client client = it.next();
				if ( isOnline( client.clientip, ports ) ) {
					it.remove();
					status.clientStatus.put( client.machineuuid, ClientStatus.ONLINE );
				}
			}
		}
		return true;
	}

	/**
	 * Changes the status of clients that got stuck because of timeouts to
	 * {@link ClientStatus#ERROR}: those still CONNECTING and, for immediate reboots, those
	 * that never came back online.
	 */
	private void markStuckClientsAsError()
	{
		for ( Map.Entry<String, ClientStatus> entry : status.clientStatus.entrySet() ) {
			ClientStatus value = entry.getValue();
			if ( value == ClientStatus.CONNECTING || ( minutes == 0 && value == ClientStatus.REBOOTING ) ) {
				entry.setValue( ClientStatus.ERROR );
			}
		}
	}

	/**
	 * Polls the channel every 100 ms until it is closed or roughly {@code timeout} ms passed.
	 * Returns early (with the interrupt flag restored) if the thread is interrupted.
	 */
	private void waitForCommand( ChannelExec channel, int timeout )
	{
		for ( int i = 0; i < timeout / 100; i++ ) {
			if ( channel.isClosed() ) {
				break;
			}
			try {
				Thread.sleep( 100 );
			} catch ( InterruptedException e ) {
				// Preserve the interrupt for the caller/executor and stop waiting
				Thread.currentThread().interrupt();
				break;
			}
		}
	}

	/**
	 * Checks whether a TCP connection to any of the given ports succeeds within 1 s per port.
	 *
	 * @return true as soon as one port accepts a connection, false if none does
	 */
	private boolean isOnline( String address, int... ports )
	{
		for ( int port : ports ) {
			try ( Socket s = new Socket() ) {
				s.connect( new InetSocketAddress( address, port ), 1000 );
				return true;
			} catch ( Exception ignored ) {
				// Best-effort probe: any failure just means "not reachable on this port"
			}
		}
		return false;
	}

	/**
	 * Output - contains additional status data of this task
	 * <p>
	 * Registered via {@code setStatusObject}; presumably serialized by the task framework,
	 * hence the seemingly unused fields. {@link #clientStatus} is written concurrently by the
	 * worker threads.
	 */
	@SuppressWarnings( "unused" )
	static class Output
	{
		/** Per-client state, keyed by machine UUID. */
		private Map<String, ClientStatus> clientStatus = new ConcurrentHashMap<>();
		private Client[] clients;
		/** Time of the scheduled action, formatted HH:mm. */
		private String time;
		private String locationId;
		private String locationName;
		/** Newline-terminated error messages, or null if there were none. */
		private String error;

		/** Appends a message (plus newline) to {@link #error}; safe to call from any thread. */
		private synchronized void addError( String e )
		{
			if ( error == null ) {
				error = e + "\n";
			} else {
				error += e + "\n";
			}
		}
	}

	/** A target client as supplied in the task parameters. */
	static class Client
	{
		@Expose
		private String machineuuid;
		@Expose
		private String clientip;
	}

	/** State of a single client, as reported in {@link Output#clientStatus}. */
	static enum ClientStatus
	{
		CONNECTING,
		REBOOTING,
		REBOOT_AT,
		SHUTDOWN,
		SHUTDOWN_AT,
		ONLINE,
		ERROR,
		AUTH_FAIL;
	}
}