Skip to content

Commit

Permalink
Add OrcSchemaConversionValidator to avoid infinite recursion in AvroO…
Browse files Browse the repository at this point in the history
…rcSchemaConverter.getOrcSchema()
  • Loading branch information
Tao Qin committed Oct 18, 2023
1 parent 6266a12 commit 55b8951
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gobblin.source.extractor.extract.kafka.validator;

import java.io.IOException;
import org.apache.avro.Schema;
import org.apache.gobblin.configuration.State;
import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistry;
import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistryFactory;
import org.apache.gobblin.kafka.schemareg.SchemaRegistryException;
import org.apache.gobblin.source.extractor.extract.kafka.KafkaTopic;
import org.apache.gobblin.util.orc.AvroOrcSchemaConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class OrcSchemaConversionValidator extends TopicValidatorBase {
private static final Logger LOGGER = LoggerFactory.getLogger(OrcSchemaConversionValidator.class);

public static final String MAX_RECURSIVE_DEPTH_KEY = "gobblin.kafka.topicValidators.orcSchemaConversionValidator.maxRecursiveDepth";
public static final int DEFAULT_MAX_RECURSIVE_DEPTH = 200;

private final KafkaSchemaRegistry schemaRegistry;

public OrcSchemaConversionValidator(State sourceState) {
super(sourceState);
this.schemaRegistry = KafkaSchemaRegistryFactory.getSchemaRegistry(sourceState.getProperties());
}

@Override
public boolean validate(KafkaTopic topic) throws Exception {
LOGGER.debug("Validating ORC schema conversion for topic {}", topic.getName());
try {
Schema schema = (Schema) this.schemaRegistry.getLatestSchema(topic.getName());
// Try converting the avro schema to orc schema to check if any errors.
int maxRecursiveDepth = this.state.getPropAsInt(MAX_RECURSIVE_DEPTH_KEY, DEFAULT_MAX_RECURSIVE_DEPTH);
AvroOrcSchemaConverter.tryGetOrcSchema(schema, 0, maxRecursiveDepth);
} catch (StackOverflowError e) {
LOGGER.warn("Failed to covert latest schema to ORC schema for topic: {}", topic.getName());
return false;
} catch (IOException | SchemaRegistryException e) {
LOGGER.warn("Failed to get latest schema for topic: {}, validation is skipped, exception: ", topic.getName(), e);
return true;
}
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gobblin.source.extractor.extract.kafka.validator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.util.Map;
import java.util.Properties;
import org.apache.avro.Schema;
import org.apache.gobblin.configuration.State;
import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistry;
import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistryConfigurationKeys;
import org.apache.gobblin.kafka.schemareg.SchemaRegistryException;
import org.apache.gobblin.kafka.serialize.MD5Digest;
import org.apache.gobblin.source.extractor.extract.kafka.KafkaTopic;
import org.testng.Assert;
import org.testng.annotations.Test;


public class OrcSchemaConversionValidatorTest {
@Test
public void testOrcSchemaConversionValidator() throws Exception {
// topic 1's schema has max depth = 1
KafkaTopic topic1 = new KafkaTopic("topic1", ImmutableList.of());
// topic 2's schema has max depth = 2
KafkaTopic topic2 = new KafkaTopic("topic2", ImmutableList.of());
// topic 3's schema has recursive filed reference
KafkaTopic topic3 = new KafkaTopic("topic3", ImmutableList.of());

State state = new State();
// Use the test schema registry to get the schema for the above topic.
state.setProp(KafkaSchemaRegistryConfigurationKeys.KAFKA_SCHEMA_REGISTRY_CLASS, TestKafkaSchemaRegistry.class.getName());

// Validate with default max_recursive_depth (=200).
OrcSchemaConversionValidator validator = new OrcSchemaConversionValidator(state);
Assert.assertTrue(validator.validate(topic1)); // Pass validation
Assert.assertTrue(validator.validate(topic2)); // Pass validation
Assert.assertFalse(validator.validate(topic3)); // Fail validation, default max_recursive_depth = 200, the validation returns early

// Validate with max_recursive_depth=1
state.setProp(OrcSchemaConversionValidator.MAX_RECURSIVE_DEPTH_KEY, 1);
Assert.assertTrue(validator.validate(topic1)); // Pass validation
Assert.assertFalse(validator.validate(topic2)); // Fail validation, because max_recursive_depth is set to 1, the validation returns early
Assert.assertFalse(validator.validate(topic3)); // Fail validation, because max_recursive_depth is set to 1, the validation returns early
}

@Test
public void testGetLatestSchemaFail() throws Exception {
KafkaTopic topic1 = new KafkaTopic("topic1", ImmutableList.of());
KafkaTopic topic2 = new KafkaTopic("topic2", ImmutableList.of());
KafkaTopic topic3 = new KafkaTopic("topic3", ImmutableList.of());
State state = new State();
state.setProp(KafkaSchemaRegistryConfigurationKeys.KAFKA_SCHEMA_REGISTRY_CLASS, BadKafkaSchemaRegistry.class.getName());

OrcSchemaConversionValidator validator = new OrcSchemaConversionValidator(state);
// Validator should always return PASS when it fails to get latest schema.
Assert.assertTrue(validator.validate(topic1));
Assert.assertTrue(validator.validate(topic2));
Assert.assertTrue(validator.validate(topic3));
}

// A KafkaSchemaRegistry class that returns the hardcoded schemas for the test topics.
public static class TestKafkaSchemaRegistry implements KafkaSchemaRegistry<MD5Digest, Schema> {
private final String schemaMaxInnerFieldDepthIs1 = "{"
+ "\"type\": \"record\","
+ " \"name\": \"test\","
+ " \"fields\": ["
+ " {\n"
+ " \"name\": \"id\","
+ " \"type\": \"int\""
+ " },"
+ " {"
+ " \"name\": \"timestamp\","
+ " \"type\": \"string\""
+ " }"
+ " ]"
+ "}";

private final String schemaMaxInnerFieldDepthIs2 = "{"
+ " \"type\": \"record\","
+ " \"name\": \"nested\","
+ " \"fields\": ["
+ " {"
+ " \"name\": \"nestedId\","
+ " \"type\": {\n"
+ " \"type\": \"array\","
+ " \"items\": \"string\""
+ " }"
+ " },"
+ " {"
+ " \"name\": \"timestamp\","
+ " \"type\": \"string\""
+ " }"
+ " ]"
+ "}";

private final String schemaWithRecursiveRef = "{"
+ " \"type\": \"record\","
+ " \"name\": \"TreeNode\","
+ " \"fields\": ["
+ " {"
+ " \"name\": \"value\","
+ " \"type\": \"int\""
+ " },"
+ " {"
+ " \"name\": \"children\","
+ " \"type\": {"
+ " \"type\": \"array\","
+ " \"items\": \"TreeNode\""
+ " }"
+ " }"
+ " ]"
+ "}";
private final Map<String, Schema> topicToSchema;

public TestKafkaSchemaRegistry(Properties props) {
topicToSchema = ImmutableMap.of(
"topic1", new Schema.Parser().parse(schemaMaxInnerFieldDepthIs1),
"topic2", new Schema.Parser().parse(schemaMaxInnerFieldDepthIs2),
"topic3", new Schema.Parser().parse(schemaWithRecursiveRef));
}
@Override
public Schema getLatestSchema(String topicName) {
return topicToSchema.get(topicName);
}

@Override
public MD5Digest register(String name, Schema schema) {
return null;
}

@Override
public Schema getById(MD5Digest id) {
return null;
}

@Override
public boolean hasInternalCache() {
return false;
}
}

// A KafkaSchemaRegistry class that always fail to get latest schema.
public static class BadKafkaSchemaRegistry implements KafkaSchemaRegistry<MD5Digest, Schema> {
public BadKafkaSchemaRegistry(Properties props) {
}

@Override
public Schema getLatestSchema(String name) throws IOException, SchemaRegistryException {
throw new SchemaRegistryException("Exception in getLatestSchema()");
}

@Override
public MD5Digest register(String name, Schema schema) throws IOException, SchemaRegistryException {
return null;
}

@Override
public Schema getById(MD5Digest id) throws IOException, SchemaRegistryException {
return null;
}

@Override
public boolean hasInternalCache() {
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ public void testValidatorTimeout() {
Assert.assertEquals(validTopics.get(0).getName(), "topic2");
}

@Test
public void testValidatorThrowingException() {
List<String> allTopics = Arrays.asList("topic1", "topic2");
List<KafkaTopic> topics = buildKafkaTopics(allTopics);
State state = new State();
state.setProp(TopicValidators.VALIDATOR_CLASSES_KEY, ValidatorThrowingException.class.getName());
List<KafkaTopic> validTopics = new TopicValidators(state).validate(topics);

Assert.assertEquals(validTopics.size(), 2); // validator throws exceptions, so all topics are treated as valid
Assert.assertTrue(validTopics.stream().anyMatch(topic -> topic.getName().equals("topic1")));
Assert.assertTrue(validTopics.stream().anyMatch(topic -> topic.getName().equals("topic2")));
}

private List<KafkaTopic> buildKafkaTopics(List<String> topics) {
return topics.stream()
.map(topicName -> new KafkaTopic(topicName, Collections.emptyList()))
Expand Down Expand Up @@ -109,4 +122,15 @@ public boolean validate(KafkaTopic topic) {
return false;
}
}

public static class ValidatorThrowingException extends TopicValidatorBase {
public ValidatorThrowingException(State state) {
super(state);
}

@Override
public boolean validate(KafkaTopic topic) throws Exception {
throw new Exception("Always throw exception");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,24 @@
* A utility class that provides a method to convert {@link Schema} into {@link TypeDescription}.
*/
public class AvroOrcSchemaConverter {
// Convert avro schema to orc schema, calling tryGetOrcSchema without recursive depth limit for backward compatibility
public static TypeDescription getOrcSchema(Schema avroSchema) {
return tryGetOrcSchema(avroSchema, 0, Integer.MAX_VALUE - 1);
}

/**
* Try converting the avro schema into {@link TypeDescription}, but with max recursive depth to avoid stack overflow.
* A typical use case is the topic validation during work unit creation.
* @param avroSchema The avro schema to convert
* @param currentDepth Current depth of the recursive call
* @param maxDepth Max depth of the recursive call
* @return the converted {@link TypeDescription}
*/
public static TypeDescription tryGetOrcSchema(Schema avroSchema, int currentDepth, int maxDepth)
throws StackOverflowError {
if (currentDepth == maxDepth + 1) {
throw new StackOverflowError("Recursive call of tryGetOrcSchema() reaches max depth " + maxDepth);
}

final Schema.Type type = avroSchema.getType();
switch (type) {
Expand All @@ -43,12 +60,12 @@ public static TypeDescription getOrcSchema(Schema avroSchema) {
case FIXED:
return getTypeDescriptionForBinarySchema(avroSchema);
case ARRAY:
return TypeDescription.createList(getOrcSchema(avroSchema.getElementType()));
return TypeDescription.createList(tryGetOrcSchema(avroSchema.getElementType(), currentDepth + 1, maxDepth));
case RECORD:
final TypeDescription recordStruct = TypeDescription.createStruct();
for (Schema.Field field2 : avroSchema.getFields()) {
final Schema fieldSchema = field2.schema();
final TypeDescription fieldType = getOrcSchema(fieldSchema);
final TypeDescription fieldType = tryGetOrcSchema(fieldSchema, currentDepth + 1, maxDepth);
if (fieldType != null) {
recordStruct.addField(field2.name(), fieldType);
} else {
Expand All @@ -59,19 +76,19 @@ public static TypeDescription getOrcSchema(Schema avroSchema) {
case MAP:
return TypeDescription.createMap(
// in Avro maps, keys are always strings
TypeDescription.createString(), getOrcSchema(avroSchema.getValueType()));
TypeDescription.createString(), tryGetOrcSchema(avroSchema.getValueType(), currentDepth + 1, maxDepth));
case UNION:
final List<Schema> nonNullMembers = getNonNullMembersOfUnion(avroSchema);
if (isNullableUnion(avroSchema, nonNullMembers)) {
// a single non-null union member
// this is how Avro represents "nullable" types; as a union of the NULL type with another
// since ORC already supports nullability of all types, just use the child type directly
return getOrcSchema(nonNullMembers.get(0));
return tryGetOrcSchema(nonNullMembers.get(0), currentDepth + 1, maxDepth);
} else {
// not a nullable union type; represent as an actual ORC union of them
final TypeDescription union = TypeDescription.createUnion();
for (final Schema childSchema : nonNullMembers) {
union.addUnionChild(getOrcSchema(childSchema));
union.addUnionChild(tryGetOrcSchema(childSchema, currentDepth + 1, maxDepth));
}
return union;
}
Expand Down

0 comments on commit 55b8951

Please sign in to comment.