// TlsChannelStrategy.java
/*
* @copyright defined in LICENSE.txt
*/
package hera.strategy;
import static hera.util.ValidationUtils.assertNotNull;
import static org.slf4j.LoggerFactory.getLogger;
import hera.exception.RpcException;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.okhttp.OkHttpChannelBuilder;
import java.io.InputStream;
import javax.net.ssl.SSLSocketFactory;
import lombok.ToString;
import org.slf4j.Logger;
/**
 * Security configuration strategy that enables TLS with a client certificate on a gRPC
 * channel builder.
 *
 * <p>The server certificate, client certificate and client key are held as
 * {@link InputStream}s and handed to the channel builder inside
 * {@link #configure(ManagedChannelBuilder)}. This class never buffers, resets or closes those
 * streams itself.
 *
 * <p>NOTE(review): since the very same stream instances are passed on every call, invoking
 * {@code configure} more than once on one instance presumably reads already-consumed streams
 * — confirm against callers before reusing an instance.
 */
@ToString(exclude = "logger")
public class TlsChannelStrategy implements SecurityConfigurationStrategy {

  protected final Logger logger = getLogger(getClass());

  /** Server common name (CN); used as the overridden authority of the channel. */
  protected final String serverCommonName;

  /** Stream of the server certificate; used as trust material. */
  protected final InputStream serverCertInputStream;

  /** Stream of the client certificate; used as key material together with the client key. */
  protected final InputStream clientCertInputStream;

  /** Stream of the client private key; used as key material together with the client cert. */
  protected final InputStream clientKeyInputStream;

  /**
   * TlsChannelStrategy constructor. Every argument is checked with {@code assertNotNull}.
   *
   * @param serverCommonName a server common name (CN)
   * @param serverCertInputStream a server certification input stream
   * @param clientCertInputStream a client certification input stream
   * @param clientKeyInputStream a client key input stream
   */
  public TlsChannelStrategy(final String serverCommonName, final InputStream serverCertInputStream,
      final InputStream clientCertInputStream, final InputStream clientKeyInputStream) {
    assertNotNull(serverCommonName, "Server common name must not null");
    assertNotNull(serverCertInputStream, "Server cert input stream must not null");
    assertNotNull(clientCertInputStream, "Client cert input stream must not null");
    assertNotNull(clientKeyInputStream, "Client key input stream must not null");
    this.serverCommonName = serverCommonName;
    this.serverCertInputStream = serverCertInputStream;
    this.clientCertInputStream = clientCertInputStream;
    this.clientKeyInputStream = clientKeyInputStream;
  }

  /**
   * Applies the TLS settings to {@code builder}.
   *
   * <p>For a {@link NettyChannelBuilder} an {@link SslContext} is built from the server
   * certificate (trust manager) and the client certificate/key pair (key manager). For an
   * {@link OkHttpChannelBuilder} the TLS material is <b>not</b> applied yet; a warning is
   * logged. In both cases the authority is overridden with the server common name and
   * transport security is switched on.
   *
   * @param builder a channel builder to configure
   * @throws RpcException if {@code builder} is null or of an unsupported type, or if any other
   *         exception is raised while configuring (wrapped as the cause)
   */
  @Override
  public void configure(final ManagedChannelBuilder<?> builder) {
    logger.debug(
        "Configure tls with serverCertStream: {}, clientCertStream: {}, clientKeyStream: {}",
        serverCertInputStream, clientCertInputStream, clientKeyInputStream);
    try {
      // Fail with a descriptive message instead of a wrapped NullPointerException.
      if (null == builder) {
        throw new RpcException("Channel builder must not be null");
      }

      if (builder instanceof NettyChannelBuilder) {
        final SslContext sslContext = GrpcSslContexts.forClient()
            .trustManager(serverCertInputStream)
            .keyManager(clientCertInputStream, clientKeyInputStream)
            .build();
        ((NettyChannelBuilder) builder).sslContext(sslContext);
      } else if (builder instanceof OkHttpChannelBuilder) {
        // TODO : not yet implemented
        // None of the supplied certificate/key streams are used on this path, so say so
        // loudly rather than silently dropping the caller's TLS material.
        // NOTE(review): with a null factory okhttp presumably falls back to its default
        // socket factory — confirm against the grpc-okhttp version in use.
        logger.warn("Supplied tls certificates and key are not applied to {}:"
            + " okhttp tls is not yet implemented", builder.getClass());
        final SSLSocketFactory sslSocketFactory = null;
        ((OkHttpChannelBuilder) builder).sslSocketFactory(sslSocketFactory);
      } else {
        throw new RpcException("Unsupported channel builder type " + builder.getClass());
      }

      builder.overrideAuthority(serverCommonName).useTransportSecurity();
    } catch (final RpcException e) {
      // Already the exception type callers expect; do not wrap it a second time.
      throw e;
    } catch (final Exception e) {
      throw new RpcException(e);
    }
  }

  /**
   * Considers every {@link SecurityConfigurationStrategy} equal to this one, regardless of its
   * concrete type or field values.
   *
   * <p>NOTE(review): {@link #hashCode()} is specific to {@code TlsChannelStrategy}, so another
   * {@code SecurityConfigurationStrategy} implementation may compare equal here while hashing
   * differently — verify against the sibling implementations and hash-based callers.
   *
   * @param obj an object to compare with
   * @return true if {@code obj} is a {@code SecurityConfigurationStrategy}
   */
  @Override
  public boolean equals(final Object obj) {
    // instanceof is already false for null, so no separate null check is needed.
    return obj instanceof SecurityConfigurationStrategy;
  }

  /**
   * Returns the same value for every instance of this class.
   *
   * @return hash code of the {@code TlsChannelStrategy} class object
   */
  @Override
  public int hashCode() {
    return TlsChannelStrategy.class.hashCode();
  }

}