Skip to content

Commit

Permalink
Add overloaded init methods that take the public key from a stream an… (
Browse files Browse the repository at this point in the history
hierynomus#908)

* Add overloaded init methods that take the public key from a stream and properly initialize. Resolves hierynomus#907.

* Override public key.
  • Loading branch information
dkocher authored Apr 29, 2024
1 parent 607e805 commit 09e2ca5
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.hierynomus.sshj.transport.cipher.BlockCiphers;
import com.hierynomus.sshj.transport.cipher.ChachaPolyCiphers;
import com.hierynomus.sshj.transport.cipher.GcmCiphers;
import com.hierynomus.sshj.userauth.keyprovider.bcrypt.BCrypt;
import net.i2p.crypto.eddsa.EdDSAPrivateKey;
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable;
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec;
Expand All @@ -29,24 +30,23 @@
import net.schmizz.sshj.userauth.keyprovider.BaseFileKeyProvider;
import net.schmizz.sshj.userauth.keyprovider.FileKeyProvider;
import net.schmizz.sshj.userauth.keyprovider.KeyFormat;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.jce.spec.ECNamedCurveSpec;
import com.hierynomus.sshj.userauth.keyprovider.bcrypt.BCrypt;
import org.bouncycastle.openssl.EncryptionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.io.*;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import java.util.Arrays;
Expand Down Expand Up @@ -83,6 +83,12 @@ public class OpenSSHKeyV1KeyFile extends BaseFileKeyProvider {

private PublicKey pubKey;

@Override
public PublicKey getPublic()
throws IOException {
return pubKey != null ? pubKey : super.getPublic();
}

public static class Factory
implements net.schmizz.sshj.common.Factory.Named<FileKeyProvider> {

Expand All @@ -100,16 +106,41 @@ public String getName() {
protected final Logger log = LoggerFactory.getLogger(getClass());

@Override
public void init(File location) {
public void init(File location, PasswordFinder pwdf) {
File pubKey = OpenSSHKeyFileUtil.getPublicKeyFile(location);
if (pubKey != null)
if (pubKey != null) {
try {
initPubKey(new FileReader(pubKey));
} catch (IOException e) {
// let super provide both public & private key
log.warn("Error reading public key file: {}", e.toString());
}
super.init(location);
}
super.init(location, pwdf);
}

@Override
public void init(String privateKey, String publicKey, PasswordFinder pwdf) {
if (pubKey != null) {
try {
initPubKey(new StringReader(publicKey));
} catch (IOException e) {
log.warn("Error reading public key file: {}", e.toString());
}
}
super.init(privateKey, null, pwdf);
}

@Override
public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) {
if (pubKey != null) {
try {
initPubKey(publicKey);
} catch (IOException e) {
log.warn("Error reading public key file: {}", e.toString());
}
}
super.init(privateKey, null, pwdf);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,38 +34,47 @@ public abstract class BaseFileKeyProvider implements FileKeyProvider {

@Override
public void init(Reader location) {
assert location != null;
resource = new PrivateKeyReaderResource(location);
this.init(location, (PasswordFinder) null);
}

@Override
public void init(Reader location, PasswordFinder pwdf) {
init(location);
this.init(location, null, pwdf);
}

@Override
public void init(Reader privateKey, Reader publicKey) {
this.init(privateKey, publicKey, null);
}

@Override
public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) {
assert publicKey == null;
this.resource = new PrivateKeyReaderResource(privateKey);
this.pwdf = pwdf;
}

@Override
public void init(File location) {
assert location != null;
resource = new PrivateKeyFileResource(location.getAbsoluteFile());
this.init(location, null);
}

@Override
public void init(File location, PasswordFinder pwdf) {
init(location);
this.resource = new PrivateKeyFileResource(location.getAbsoluteFile());
this.pwdf = pwdf;
}

@Override
public void init(String privateKey, String publicKey) {
assert privateKey != null;
assert publicKey == null;
resource = new PrivateKeyStringResource(privateKey);
this.init(privateKey, publicKey, null);
}

@Override
public void init(String privateKey, String publicKey, PasswordFinder pwdf) {
init(privateKey, publicKey);
assert privateKey != null;
assert publicKey == null;
this.resource = new PrivateKeyStringResource(privateKey);
this.pwdf = pwdf;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ public interface FileKeyProvider

void init(Reader location);

void init(Reader privateKey, Reader publicKey);

void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf);

void init(Reader location, PasswordFinder pwdf);

void init(String privateKey, String publicKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package net.schmizz.sshj.userauth.keyprovider;

import com.hierynomus.sshj.userauth.keyprovider.OpenSSHKeyFileUtil;
import net.schmizz.sshj.userauth.password.PasswordFinder;

import java.io.*;
import java.security.PublicKey;
Expand Down Expand Up @@ -54,21 +55,22 @@ public PublicKey getPublic()
}

@Override
public void init(File location) {
public void init(File location, PasswordFinder pwdf) {
// try cert key location first
File pubKey = OpenSSHKeyFileUtil.getPublicKeyFile(location);
if (pubKey != null)
if (pubKey != null) {
try {
initPubKey(new FileReader(pubKey));
} catch (IOException e) {
// let super provide both public & private key
log.warn("Error reading public key file: {}", e.toString());
}
super.init(location);
}
super.init(location, pwdf);
}

@Override
public void init(String privateKey, String publicKey) {
public void init(String privateKey, String publicKey, PasswordFinder pwdf) {
if (publicKey != null) {
try {
initPubKey(new StringReader(publicKey));
Expand All @@ -77,7 +79,20 @@ public void init(String privateKey, String publicKey) {
log.warn("Error reading public key: {}", e.toString());
}
}
super.init(privateKey, null);
super.init(privateKey, null, pwdf);
}

@Override
public void init(Reader privateKey, Reader publicKey, PasswordFinder pwdf) {
if (publicKey != null) {
try {
initPubKey(publicKey);
} catch (IOException e) {
// let super provide both public & private key
log.warn("Error reading public key: {}", e.toString());
}
}
super.init(privateKey, null, pwdf);
}

/**
Expand Down
23 changes: 23 additions & 0 deletions src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,18 @@ public void shouldSuccessfullyLoadSignedRSAPublicKey() throws IOException {

}

@Test
public void shouldSuccessfullyLoadSignedRSAPublicKeyFromStream() throws IOException {
FileKeyProvider keyFile = new OpenSSHKeyFile();
keyFile.init(new FileReader("src/test/resources/keytypes/certificate/test_rsa"),
new FileReader("src/test/resources/keytypes/certificate/test_rsa.pub"),
PasswordUtils.createOneOff(correctPassphrase));
assertNotNull(keyFile.getPrivate());
PublicKey pubKey = keyFile.getPublic();
assertNotNull(pubKey);
assertEquals("RSA", pubKey.getAlgorithm());
}

@Test
public void shouldSuccessfullyLoadSignedRSAPublicKeyWithMaxDate() throws IOException {
FileKeyProvider keyFile = new OpenSSHKeyFile();
Expand Down Expand Up @@ -422,6 +434,17 @@ public void shouldSuccessfullyLoadSignedDSAPublicKey() throws IOException {
assertEquals("", certificate.getExtensions().get("permit-pty"));
}

@Test
public void shouldSuccessfullyLoadSignedDSAPublicKeyFromStream() throws IOException {
FileKeyProvider keyFile = new OpenSSHKeyFile();
keyFile.init(new FileReader("src/test/resources/keytypes/certificate/test_dsa"),
new FileReader("src/test/resources/keytypes/certificate/test_dsa-cert.pub"),
PasswordUtils.createOneOff(correctPassphrase));
assertNotNull(keyFile.getPrivate());
PublicKey pubKey = keyFile.getPublic();
assertEquals("DSA", pubKey.getAlgorithm());
}

/**
* Sometimes users copy-pastes private and public keys in text editors. It leads to redundant
* spaces and newlines. OpenSSH can easily read such keys, so users expect from SSHJ the same.
Expand Down

0 comments on commit 09e2ca5

Please sign in to comment.