From 55b8951fd850b0ec5dd5c5b91913ea0d9a72027c Mon Sep 17 00:00:00 2001 From: Tao Qin Date: Tue, 3 Oct 2023 17:48:59 -0700 Subject: [PATCH] Add OrcSchemaConversionValidator to avoid infinite recursion in AvroOrcSchemaConverter.getOrcSchema() --- .../OrcSchemaConversionValidator.java | 61 ++++++ .../OrcSchemaConversionValidatorTest.java | 183 ++++++++++++++++++ .../kafka/validator/TopicValidatorsTest.java | 24 +++ .../util/orc/AvroOrcSchemaConverter.java | 27 ++- 4 files changed, 290 insertions(+), 5 deletions(-) create mode 100644 gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidator.java create mode 100644 gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidatorTest.java diff --git a/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidator.java b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidator.java new file mode 100644 index 0000000000..c5e4837d2a --- /dev/null +++ b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gobblin.source.extractor.extract.kafka.validator; + +import java.io.IOException; +import org.apache.avro.Schema; +import org.apache.gobblin.configuration.State; +import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistry; +import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistryFactory; +import org.apache.gobblin.kafka.schemareg.SchemaRegistryException; +import org.apache.gobblin.source.extractor.extract.kafka.KafkaTopic; +import org.apache.gobblin.util.orc.AvroOrcSchemaConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +public class OrcSchemaConversionValidator extends TopicValidatorBase { + private static final Logger LOGGER = LoggerFactory.getLogger(OrcSchemaConversionValidator.class); + + public static final String MAX_RECURSIVE_DEPTH_KEY = "gobblin.kafka.topicValidators.orcSchemaConversionValidator.maxRecursiveDepth"; + public static final int DEFAULT_MAX_RECURSIVE_DEPTH = 200; + + private final KafkaSchemaRegistry schemaRegistry; + + public OrcSchemaConversionValidator(State sourceState) { + super(sourceState); + this.schemaRegistry = KafkaSchemaRegistryFactory.getSchemaRegistry(sourceState.getProperties()); + } + + @Override + public boolean validate(KafkaTopic topic) throws Exception { + LOGGER.debug("Validating ORC schema conversion for topic {}", topic.getName()); + try { + Schema schema = (Schema) this.schemaRegistry.getLatestSchema(topic.getName()); + // Try converting the avro schema to orc schema to check if any errors. + int maxRecursiveDepth = this.state.getPropAsInt(MAX_RECURSIVE_DEPTH_KEY, DEFAULT_MAX_RECURSIVE_DEPTH); + AvroOrcSchemaConverter.tryGetOrcSchema(schema, 0, maxRecursiveDepth); + } catch (StackOverflowError e) { + LOGGER.warn("Failed to covert latest schema to ORC schema for topic: {}", topic.getName()); + return false; + } catch (IOException | SchemaRegistryException e) { + LOGGER.warn("Failed to get latest schema for topic: {}, validation is skipped, exception: ", topic.getName(), e); + return true; + } + return true; + } +} diff --git a/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidatorTest.java b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidatorTest.java new file mode 100644 index 0000000000..e077e1a074 --- /dev/null +++ b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/OrcSchemaConversionValidatorTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gobblin.source.extractor.extract.kafka.validator; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.io.IOException; +import java.util.Map; +import java.util.Properties; +import org.apache.avro.Schema; +import org.apache.gobblin.configuration.State; +import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistry; +import org.apache.gobblin.kafka.schemareg.KafkaSchemaRegistryConfigurationKeys; +import org.apache.gobblin.kafka.schemareg.SchemaRegistryException; +import org.apache.gobblin.kafka.serialize.MD5Digest; +import org.apache.gobblin.source.extractor.extract.kafka.KafkaTopic; +import org.testng.Assert; +import org.testng.annotations.Test; + + +public class OrcSchemaConversionValidatorTest { + @Test + public void testOrcSchemaConversionValidator() throws Exception { + // topic 1's schema has max depth = 1 + KafkaTopic topic1 = new KafkaTopic("topic1", ImmutableList.of()); + // topic 2's schema has max depth = 2 + KafkaTopic topic2 = new KafkaTopic("topic2", ImmutableList.of()); + // topic 3's schema has recursive filed reference + KafkaTopic topic3 = new KafkaTopic("topic3", ImmutableList.of()); + + State state = new State(); + // Use the test schema registry to get the schema for the above topic. + state.setProp(KafkaSchemaRegistryConfigurationKeys.KAFKA_SCHEMA_REGISTRY_CLASS, TestKafkaSchemaRegistry.class.getName()); + + // Validate with default max_recursive_depth (=200). + OrcSchemaConversionValidator validator = new OrcSchemaConversionValidator(state); + Assert.assertTrue(validator.validate(topic1)); // Pass validation + Assert.assertTrue(validator.validate(topic2)); // Pass validation + Assert.assertFalse(validator.validate(topic3)); // Fail validation, default max_recursive_depth = 200, the validation returns early + + // Validate with max_recursive_depth=1 + state.setProp(OrcSchemaConversionValidator.MAX_RECURSIVE_DEPTH_KEY, 1); + Assert.assertTrue(validator.validate(topic1)); // Pass validation + Assert.assertFalse(validator.validate(topic2)); // Fail validation, because max_recursive_depth is set to 1, the validation returns early + Assert.assertFalse(validator.validate(topic3)); // Fail validation, because max_recursive_depth is set to 1, the validation returns early + } + + @Test + public void testGetLatestSchemaFail() throws Exception { + KafkaTopic topic1 = new KafkaTopic("topic1", ImmutableList.of()); + KafkaTopic topic2 = new KafkaTopic("topic2", ImmutableList.of()); + KafkaTopic topic3 = new KafkaTopic("topic3", ImmutableList.of()); + State state = new State(); + state.setProp(KafkaSchemaRegistryConfigurationKeys.KAFKA_SCHEMA_REGISTRY_CLASS, BadKafkaSchemaRegistry.class.getName()); + + OrcSchemaConversionValidator validator = new OrcSchemaConversionValidator(state); + // Validator should always return PASS when it fails to get latest schema. + Assert.assertTrue(validator.validate(topic1)); + Assert.assertTrue(validator.validate(topic2)); + Assert.assertTrue(validator.validate(topic3)); + } + + // A KafkaSchemaRegistry class that returns the hardcoded schemas for the test topics. + public static class TestKafkaSchemaRegistry implements KafkaSchemaRegistry { + private final String schemaMaxInnerFieldDepthIs1 = "{" + + "\"type\": \"record\"," + + " \"name\": \"test\"," + + " \"fields\": [" + + " {\n" + + " \"name\": \"id\"," + + " \"type\": \"int\"" + + " }," + + " {" + + " \"name\": \"timestamp\"," + + " \"type\": \"string\"" + + " }" + + " ]" + + "}"; + + private final String schemaMaxInnerFieldDepthIs2 = "{" + + " \"type\": \"record\"," + + " \"name\": \"nested\"," + + " \"fields\": [" + + " {" + + " \"name\": \"nestedId\"," + + " \"type\": {\n" + + " \"type\": \"array\"," + + " \"items\": \"string\"" + + " }" + + " }," + + " {" + + " \"name\": \"timestamp\"," + + " \"type\": \"string\"" + + " }" + + " ]" + + "}"; + + private final String schemaWithRecursiveRef = "{" + + " \"type\": \"record\"," + + " \"name\": \"TreeNode\"," + + " \"fields\": [" + + " {" + + " \"name\": \"value\"," + + " \"type\": \"int\"" + + " }," + + " {" + + " \"name\": \"children\"," + + " \"type\": {" + + " \"type\": \"array\"," + + " \"items\": \"TreeNode\"" + + " }" + + " }" + + " ]" + + "}"; + private final Map topicToSchema; + + public TestKafkaSchemaRegistry(Properties props) { + topicToSchema = ImmutableMap.of( + "topic1", new Schema.Parser().parse(schemaMaxInnerFieldDepthIs1), + "topic2", new Schema.Parser().parse(schemaMaxInnerFieldDepthIs2), + "topic3", new Schema.Parser().parse(schemaWithRecursiveRef)); + } + @Override + public Schema getLatestSchema(String topicName) { + return topicToSchema.get(topicName); + } + + @Override + public MD5Digest register(String name, Schema schema) { + return null; + } + + @Override + public Schema getById(MD5Digest id) { + return null; + } + + @Override + public boolean hasInternalCache() { + return false; + } + } + + // A KafkaSchemaRegistry class that always fail to get latest schema. + public static class BadKafkaSchemaRegistry implements KafkaSchemaRegistry { + public BadKafkaSchemaRegistry(Properties props) { + } + + @Override + public Schema getLatestSchema(String name) throws IOException, SchemaRegistryException { + throw new SchemaRegistryException("Exception in getLatestSchema()"); + } + + @Override + public MD5Digest register(String name, Schema schema) throws IOException, SchemaRegistryException { + return null; + } + + @Override + public Schema getById(MD5Digest id) throws IOException, SchemaRegistryException { + return null; + } + + @Override + public boolean hasInternalCache() { + return false; + } + } +} diff --git a/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/TopicValidatorsTest.java b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/TopicValidatorsTest.java index 2691ae112c..4390190ec4 100644 --- a/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/TopicValidatorsTest.java +++ b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/validator/TopicValidatorsTest.java @@ -67,6 +67,19 @@ public void testValidatorTimeout() { Assert.assertEquals(validTopics.get(0).getName(), "topic2"); } + @Test + public void testValidatorThrowingException() { + List allTopics = Arrays.asList("topic1", "topic2"); + List topics = buildKafkaTopics(allTopics); + State state = new State(); + state.setProp(TopicValidators.VALIDATOR_CLASSES_KEY, ValidatorThrowingException.class.getName()); + List validTopics = new TopicValidators(state).validate(topics); + + Assert.assertEquals(validTopics.size(), 2); // validator throws exceptions, so all topics are treated as valid + Assert.assertTrue(validTopics.stream().anyMatch(topic -> topic.getName().equals("topic1"))); + Assert.assertTrue(validTopics.stream().anyMatch(topic -> topic.getName().equals("topic2"))); + } + private List buildKafkaTopics(List topics) { return topics.stream() .map(topicName -> new KafkaTopic(topicName, Collections.emptyList())) @@ -109,4 +122,15 @@ public boolean validate(KafkaTopic topic) { return false; } } + + public static class ValidatorThrowingException extends TopicValidatorBase { + public ValidatorThrowingException(State state) { + super(state); + } + + @Override + public boolean validate(KafkaTopic topic) throws Exception { + throw new Exception("Always throw exception"); + } + } } diff --git a/gobblin-utility/src/main/java/org/apache/gobblin/util/orc/AvroOrcSchemaConverter.java b/gobblin-utility/src/main/java/org/apache/gobblin/util/orc/AvroOrcSchemaConverter.java index 9f227ca5d5..cc885a987f 100644 --- a/gobblin-utility/src/main/java/org/apache/gobblin/util/orc/AvroOrcSchemaConverter.java +++ b/gobblin-utility/src/main/java/org/apache/gobblin/util/orc/AvroOrcSchemaConverter.java @@ -27,7 +27,24 @@ * A utility class that provides a method to convert {@link Schema} into {@link TypeDescription}. */ public class AvroOrcSchemaConverter { + // Convert avro schema to orc schema, calling tryGetOrcSchema without recursive depth limit for backward compatibility public static TypeDescription getOrcSchema(Schema avroSchema) { + return tryGetOrcSchema(avroSchema, 0, Integer.MAX_VALUE - 1); + } + + /** + * Try converting the avro schema into {@link TypeDescription}, but with max recursive depth to avoid stack overflow. + * A typical use case is the topic validation during work unit creation. + * @param avroSchema The avro schema to convert + * @param currentDepth Current depth of the recursive call + * @param maxDepth Max depth of the recursive call + * @return the converted {@link TypeDescription} + */ + public static TypeDescription tryGetOrcSchema(Schema avroSchema, int currentDepth, int maxDepth) + throws StackOverflowError { + if (currentDepth == maxDepth + 1) { + throw new StackOverflowError("Recursive call of tryGetOrcSchema() reaches max depth " + maxDepth); + } final Schema.Type type = avroSchema.getType(); switch (type) { @@ -43,12 +60,12 @@ public static TypeDescription getOrcSchema(Schema avroSchema) { case FIXED: return getTypeDescriptionForBinarySchema(avroSchema); case ARRAY: - return TypeDescription.createList(getOrcSchema(avroSchema.getElementType())); + return TypeDescription.createList(tryGetOrcSchema(avroSchema.getElementType(), currentDepth + 1, maxDepth)); case RECORD: final TypeDescription recordStruct = TypeDescription.createStruct(); for (Schema.Field field2 : avroSchema.getFields()) { final Schema fieldSchema = field2.schema(); - final TypeDescription fieldType = getOrcSchema(fieldSchema); + final TypeDescription fieldType = tryGetOrcSchema(fieldSchema, currentDepth + 1, maxDepth); if (fieldType != null) { recordStruct.addField(field2.name(), fieldType); } else { @@ -59,19 +76,19 @@ public static TypeDescription getOrcSchema(Schema avroSchema) { case MAP: return TypeDescription.createMap( // in Avro maps, keys are always strings - TypeDescription.createString(), getOrcSchema(avroSchema.getValueType())); + TypeDescription.createString(), tryGetOrcSchema(avroSchema.getValueType(), currentDepth + 1, maxDepth)); case UNION: final List nonNullMembers = getNonNullMembersOfUnion(avroSchema); if (isNullableUnion(avroSchema, nonNullMembers)) { // a single non-null union member // this is how Avro represents "nullable" types; as a union of the NULL type with another // since ORC already supports nullability of all types, just use the child type directly - return getOrcSchema(nonNullMembers.get(0)); + return tryGetOrcSchema(nonNullMembers.get(0), currentDepth + 1, maxDepth); } else { // not a nullable union type; represent as an actual ORC union of them final TypeDescription union = TypeDescription.createUnion(); for (final Schema childSchema : nonNullMembers) { - union.addUnionChild(getOrcSchema(childSchema)); + union.addUnionChild(tryGetOrcSchema(childSchema, currentDepth + 1, maxDepth)); } return union; }