I've been looking for a cleaner way to handle multiple connections in Java for a while. I've done it in C with select(), but until now I had no way to do it in Java. For everybody's benefit, I'll post the code here.
A complete program, which makes three HTTP requests and runs an echo server, all in the same thread, can be found here:
http://www.javaop.com/~iago/nonblocking.tgz . It is, of course, released into the public domain.
If you just want to have a look, here's the code for the main part. It won't compile on its own because it needs a couple of interfaces, but you can still read it if you want to learn:
Edit: By the way, I should also mention that this requires Java 1.5 or higher. If you look carefully, you'll see I'm using Java's new generics (similar in spirit to C++ templates), which I really love. They make code WAY cleaner and give me a warm feeling.
package util.socket_manager;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.AbstractSelectableChannel;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
/**
 * Manages multiple non-blocking connections -- outgoing clients, accepted clients, and listening
 * servers -- multiplexed over a single {@link Selector}. Each connection is identified by a unique
 * string, so this behaves somewhat like a hashtable of sockets.
 *
 * <p>No background thread is involved: all socket I/O and all callbacks run in whichever thread
 * calls {@link #doSelect()}, typically in an endless loop. (The class extends {@link Thread} but
 * does not override {@code run()}, so starting it does nothing useful.)
 *
 * @author iago
 */
public class SocketManager extends Thread
{
    /** The select timeout, in milliseconds; a select that waits this long with nothing ready triggers timeout() on every connection. */
    public final static long TIMEOUT = 3000;

    /** The receive-buffer size of each connection, i.e. the most bytes a single read can deliver. */
    public final static int MAX_PACKET_SIZE = 4096;

    /** The selector used for multiplexing the sockets. */
    private final Selector selector;

    /** All connections (servers and clients), indexed by their unique identifiers. */
    private final Hashtable<String, ConnectionData<?, ?>> sockets_by_identifier;

    /** All client connections, indexed by their SocketChannels. */
    private final Hashtable<SocketChannel, ConnectionData<ClientCallback, SocketChannel>> client_sockets_by_socket;

    /** All server connections, indexed by their ServerSocketChannels. */
    private final Hashtable<ServerSocketChannel, ConnectionData<ServerCallback, ServerSocketChannel>> server_sockets_by_socket;

    /**
     * Set whenever this class wakes the selector. A deliberate wakeup also makes select() return 0,
     * so doSelect() uses this flag to avoid mistaking it for a timeout.
     */
    private volatile boolean wakeup_requested;

    /**
     * Creates a new SocketManager with its own selector. No thread is started; the owner must call
     * {@link #doSelect()} repeatedly to drive the connections.
     *
     * @param identifier A way for the user/programmer to distinguish various SocketManagers
     *     (currently not stored or used).
     * @throws IOException If opening the selector failed. If this happens, the program is screwed.
     */
    public SocketManager(String identifier) throws IOException
    {
        sockets_by_identifier = new Hashtable<String, ConnectionData<?, ?>>();
        client_sockets_by_socket = new Hashtable<SocketChannel, ConnectionData<ClientCallback, SocketChannel>>();
        server_sockets_by_socket = new Hashtable<ServerSocketChannel, ConnectionData<ServerCallback, ServerSocketChannel>>();
        selector = Selector.open();
    }

    /**
     * Performs one select operation lasting no longer than TIMEOUT ms. Ready sockets are dispatched
     * to the appropriate callbacks; if the select genuinely timed out, every connection's timeout()
     * callback is invoked. Generally this should be called in an infinite loop at the core of the
     * program.
     *
     * @throws IOException If an unexpected I/O error occurs in the selector. Disconnects and
     *     timeouts do not throw; they are reported through the proper callback instead.
     */
    public void doSelect() throws IOException
    {
        int ready = selector.select(TIMEOUT);
        // Read and reset the flag right after select() returns: a wakeup requested from now on
        // (e.g. by a callback below) makes the *next* select() return early instead.
        boolean wokenUp = wakeup_requested;
        wakeup_requested = false;

        if (ready > 0)
        {
            handleSockets();
        }
        else if (!wokenUp)
        {
            // select() also returns 0 after our own wakeup() calls (connect/close/send); only
            // report a timeout when nothing asked for the wakeup.
            handleTimeout();
        }
    }

    /**
     * Called when a client socket is readable. Reads whatever is available into the connection's
     * buffer and passes it to the callback, or closes the connection on end-of-stream or I/O error.
     *
     * @param key The readable key.
     */
    protected void handleReadable(SelectionKey key)
    {
        SocketChannel channel = (SocketChannel) key.channel();
        ConnectionData<ClientCallback, SocketChannel> cd = client_sockets_by_socket.get(channel);
        if (cd == null)
        {
            // Not (or no longer) a tracked client connection; nothing to deliver to.
            return;
        }

        ByteBuffer in = cd.getInBuffer();
        int bytesRead = -1;
        try
        {
            // clear() rather than rewind(): it also restores the limit, in case the callback
            // flipped the buffer while processing the previous read.
            in.clear();
            bytesRead = channel.read(in);
        }
        catch (IOException ignored)
        {
            // Deliberately treated like end-of-stream: bytesRead stays -1 and the socket is closed below.
        }

        if (bytesRead < 0)
        {
            close(cd.getIdentifier());
        }
        else
        {
            cd.getCallback().receivedData(in, bytesRead);
        }
    }

    /**
     * Called when a client socket is writable. Sends (part of) the next queued buffer; once the
     * queue is empty, the socket goes back to being watched for reads only.
     *
     * @param key The writable key.
     */
    protected void handleWritable(SelectionKey key)
    {
        SocketChannel channel = (SocketChannel) key.channel();
        ConnectionData<ClientCallback, SocketChannel> cd = client_sockets_by_socket.get(channel);
        if (cd == null)
        {
            return;
        }

        ByteBuffer pending = cd.nextOutBuffer();
        try
        {
            if (pending == null)
            {
                // Queue drained: stop asking for writability, otherwise select() would keep
                // returning immediately for an always-writable socket.
                key.interestOps(SelectionKey.OP_READ);
            }
            else
            {
                channel.write(pending);
                // A partial write leaves bytes remaining; put the buffer back at the head of the queue.
                if (pending.hasRemaining())
                {
                    cd.readdOutBuffer(pending);
                }
            }
        }
        catch (IOException e)
        {
            close(cd.getIdentifier());
        }
    }

    /**
     * Called when a server socket has a pending connection. Accepts it, registers the new client
     * for reads under a generated identifier, and notifies both the server and client callbacks.
     *
     * @param key The acceptable key.
     */
    protected void handleAcceptable(SelectionKey key)
    {
        ServerSocketChannel server = (ServerSocketChannel) key.channel();
        ConnectionData<ServerCallback, ServerSocketChannel> cd = server_sockets_by_socket.get(server);
        if (cd == null)
        {
            return;
        }

        // Holds the accepted socket until it is stored in the tables; if anything fails before
        // then, it is closed in the catch block instead of leaking.
        SocketChannel untracked = null;
        try
        {
            SocketChannel newSocket = server.accept();
            if (newSocket == null)
            {
                // A non-blocking accept() returns null when no connection is actually pending.
                return;
            }
            untracked = newSocket;

            newSocket.configureBlocking(false);
            newSocket.register(selector, SelectionKey.OP_READ);

            String identifier = getClientIdentifier(cd.getIdentifier());
            ClientCallback callback = cd.getCallback().getSocketCallback(identifier);
            ConnectionData<ClientCallback, SocketChannel> cdNew =
                    new ConnectionData<ClientCallback, SocketChannel>(identifier, newSocket, callback);
            sockets_by_identifier.put(identifier, cdNew);
            client_sockets_by_socket.put(newSocket, cdNew);
            // From here on close(identifier) owns the socket.
            untracked = null;

            cd.getCallback().connectionAccepted();
            callback.connected();
            wakeSelector();
        }
        catch (IOException e)
        {
            e.printStackTrace();
            closeQuietly(untracked);
        }
    }

    /**
     * Called when select() reports ready keys. Dispatches each selected key to the read, write or
     * accept handler, removing it from the selected-key set as it goes.
     */
    protected void handleSockets()
    {
        Iterator<SelectionKey> keys = selector.selectedKeys().iterator();
        while (keys.hasNext())
        {
            SelectionKey key = keys.next();
            // The selector never removes keys from the selected set itself.
            keys.remove();

            // A callback run earlier in this pass may have closed this connection, cancelling its
            // key; querying a cancelled key's ready set throws CancelledKeyException.
            if (!key.isValid())
            {
                continue;
            }

            if (key.isReadable())
            {
                handleReadable(key);
            }
            else if (key.isWritable())
            {
                handleWritable(key);
            }
            else if (key.isAcceptable())
            {
                handleAcceptable(key);
            }
        }
    }

    /**
     * Called when a select times out. Invokes timeout() on the callback of every connection
     * (servers and clients), e.g. so they can send keepalives.
     */
    protected void handleTimeout()
    {
        // Iterate over a snapshot: a timeout() callback may open or close connections.
        ArrayList<ConnectionData<?, ?>> snapshot = new ArrayList<ConnectionData<?, ?>>(sockets_by_identifier.values());
        for (ConnectionData<?, ?> cd : snapshot)
        {
            // Skip connections that an earlier callback in this loop closed or replaced.
            if (sockets_by_identifier.get(cd.getIdentifier()) == cd)
            {
                cd.getCallback().timeout();
            }
        }
    }

    /**
     * Closes any existing connection with this identifier, then connects to the host and stores
     * the new connection. This method does not return until the connection is established.
     *
     * @param host The host to connect to.
     * @param port The remote port.
     * @param identifier The name under which the connection is stored.
     * @param callback The callback to inform about this connection's events.
     * @throws IOException If the connect failed; the partially opened channel is closed.
     */
    public void connect(String host, int port, String identifier, ClientCallback callback) throws IOException
    {
        close(identifier);

        SocketChannel channel = SocketChannel.open();
        boolean success = false;
        try
        {
            // Connect in blocking mode, then switch to non-blocking: same "returns once connected"
            // contract as before, without spinning the CPU on finishConnect().
            channel.connect(new InetSocketAddress(host, port));
            channel.configureBlocking(false);
            channel.register(selector, SelectionKey.OP_READ);
            success = true;
        }
        finally
        {
            if (!success)
            {
                closeQuietly(channel);
            }
        }

        ConnectionData<ClientCallback, SocketChannel> cd =
                new ConnectionData<ClientCallback, SocketChannel>(identifier, channel, callback);
        sockets_by_identifier.put(identifier, cd);
        client_sockets_by_socket.put(channel, cd);

        callback.connected();
        wakeSelector();
    }

    /**
     * Listens for new incoming connections, adding the listening socket to the list of servers.
     * Any existing connection with the same identifier is closed first (as {@link #connect} does),
     * so it is not orphaned.
     *
     * @param port The port to listen on. Remember, on Linux, this has to be >=1024 unless you're root.
     * @param identifier The name to give the server, for the purposes of identification.
     * @param callback The callback to inform when something happens.
     * @throws IOException If the port couldn't be opened for listening; the channel is closed.
     */
    public void listen(int port, String identifier, ServerCallback callback) throws IOException
    {
        close(identifier);

        ServerSocketChannel server = ServerSocketChannel.open();
        boolean success = false;
        try
        {
            server.configureBlocking(false);
            server.socket().bind(new InetSocketAddress(port));
            server.register(selector, SelectionKey.OP_ACCEPT);
            success = true;
        }
        finally
        {
            if (!success)
            {
                closeQuietly(server);
            }
        }

        ConnectionData<ServerCallback, ServerSocketChannel> cd =
                new ConnectionData<ServerCallback, ServerSocketChannel>(identifier, server, callback);
        sockets_by_identifier.put(identifier, cd);
        server_sockets_by_socket.put(server, cd);
    }

    /**
     * Removes and closes the specified connection, if it exists, then notifies its callback via
     * disconnected(). Errors while closing are ignored; no exception is thrown.
     *
     * @param identifier The string that identifies the connection.
     */
    public void close(String identifier)
    {
        ConnectionData<?, ?> cd = sockets_by_identifier.remove(identifier);
        if (cd != null)
        {
            // Untrack before notifying: a disconnected() callback that reconnects or closes under
            // the same identifier must neither recurse into this connection nor have its new
            // connection removed afterwards.
            client_sockets_by_socket.remove(cd.getSocket());
            server_sockets_by_socket.remove(cd.getSocket());
            closeQuietly(cd.getSocket());
            cd.getCallback().disconnected();
        }
        wakeSelector();
    }

    /**
     * Queues data to be written to a client socket. The data is sent once the socket becomes
     * writable (probably immediately). If the socket turns out to be closed, the connection is
     * closed instead.
     *
     * @param identifier The socket identifier to send the data over.
     * @param bytes The data to send; it is rewound before being queued.
     * @throws IllegalArgumentException If the identifier isn't found, or names a listening server.
     */
    public void send(String identifier, ByteBuffer bytes)
    {
        ConnectionData<?, ?> cd = sockets_by_identifier.get(identifier);
        if (cd == null)
        {
            throw new IllegalArgumentException("No such identifier: " + identifier);
        }
        if (!(cd.getSocket() instanceof SocketChannel))
        {
            throw new IllegalArgumentException("Cannot send over a listening socket: " + identifier);
        }

        cd.addOutBuffer(bytes);

        // Closing a channel cancels its key, so a missing or invalid key means the socket is gone.
        SelectionKey key = cd.getSocket().keyFor(selector);
        if (key == null || !key.isValid())
        {
            close(identifier);
            return;
        }
        try
        {
            // Update the existing registration rather than re-registering the channel.
            key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
        }
        catch (CancelledKeyException e)
        {
            close(identifier);
            return;
        }
        wakeSelector();
    }

    /**
     * Generates a client identifier from a server name by appending "-client-" and the lowest
     * number not currently in use. This can be changed to suit the application.
     * WARNING: a number is reused once the socket that held it is closed.
     *
     * @param serverIdentifier The identifier for the server that the new name is based on.
     * @return A name not currently used by any connection.
     */
    protected String getClientIdentifier(String serverIdentifier)
    {
        int i = 0;
        String identifier;
        do
        {
            identifier = serverIdentifier + "-client-" + i;
            i++;
        }
        while (sockets_by_identifier.containsKey(identifier));
        return identifier;
    }

    /** Wakes the selector and records that the wakeup was deliberate (see {@link #doSelect()}). */
    private void wakeSelector()
    {
        wakeup_requested = true;
        selector.wakeup();
    }

    /** Closes the channel if it is non-null, ignoring I/O errors (the channel is unusable either way). */
    private static void closeQuietly(AbstractSelectableChannel channel)
    {
        if (channel == null)
        {
            return;
        }
        try
        {
            channel.close();
        }
        catch (IOException ignored)
        {
            // Best-effort close.
        }
    }
}
/**
 * Bookkeeping for one connection owned by SocketManager: its identifier, channel and callback,
 * a receive buffer of SocketManager.MAX_PACKET_SIZE bytes allocated once per connection, and the
 * queue of buffers still waiting to be written. This is a package-private top-level class.
 *
 * @param <CallbackType> the callback type notified about this connection's events
 * @param <SocketType> the kind of channel (client or server socket)
 */
class ConnectionData <CallbackType extends SocketCallback, SocketType extends AbstractSelectableChannel>
{
    private final String identifier;
    private final SocketType channel;
    private final CallbackType callback;
    private final ByteBuffer readBuffer = ByteBuffer.allocate(SocketManager.MAX_PACKET_SIZE);
    private final LinkedList<ByteBuffer> pendingWrites = new LinkedList<ByteBuffer>();

    public ConnectionData (String identifier, SocketType s, CallbackType callback)
    {
        this.identifier = identifier;
        this.channel = s;
        this.callback = callback;
    }

    /** @return the unique name of this connection */
    public String getIdentifier()
    {
        return identifier;
    }

    /** @return the channel backing this connection */
    public SocketType getSocket()
    {
        return channel;
    }

    /** @return the callback notified about this connection */
    public CallbackType getCallback()
    {
        return callback;
    }

    /** @return this connection's receive buffer (the same instance on every call) */
    public ByteBuffer getInBuffer()
    {
        return readBuffer;
    }

    /** Rewinds the buffer and appends it to the end of the write queue. */
    public synchronized void addOutBuffer(ByteBuffer buffer)
    {
        buffer.rewind();
        pendingWrites.addLast(buffer);
    }

    /**
     * Puts a partially written buffer back at the head of the queue, without rewinding it, so
     * its remaining bytes go out before anything queued after it.
     */
    public synchronized void readdOutBuffer(ByteBuffer buffer)
    {
        pendingWrites.addFirst(buffer);
    }

    /** Removes and returns the head of the write queue, or null when nothing is queued. */
    public synchronized ByteBuffer nextOutBuffer()
    {
        return pendingWrites.poll();
    }
}