Skip to content

Commit

Permalink
Feat(Analysis): 취약점 분석 알고리즘 개선 (#75)
Browse files Browse the repository at this point in the history
* Fix(Problem): 북마크 조회 시 카운트 쿼리와 내용 쿼리가 서로 다른 현상 수정

* Feat(Analysis): 취약점 분석 알고리즘 개선

- 기존에 N일 경과 시 1/(N+1)의 가중치를 부과하던 것에서, 경과일과 최근에 푼 순서에 따라 단기/중기/장기로 나눠 가중치 부과

* Test(Analysis): 변경된 알고리즘에 따른 테스트 코드 수정

* Feat(Analysis): ProblemSolvingAnalysisType 관련 코드를 직관적으로 이해하기 쉽도록 변경 및 주석 추가

- maxPeriod를 maxPeriodDay로 변경
- getLongestAnalysisType을 고정된 값으로 반환하도록 변경

* Feat(Analysis): 취약벡터 계산 시 이해하기 쉽도록 주석을 추가

* Test(Analysis): 경계값 테스트에 더 적합한 수치로 변경
  • Loading branch information
morenow98 authored Sep 23, 2024
1 parent 31d16e5 commit b03fd49
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 76 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.jabiseo.analysis.domain;

import com.jabiseo.analysis.exception.AnalysisBusinessException;
import com.jabiseo.analysis.exception.AnalysisErrorCode;
import lombok.Getter;

import java.util.Arrays;
import java.util.Comparator;

/**
 * Recency buckets used when weighting a member's problem-solving records.
 *
 * <p>Each constant carries the weight applied to a record, the maximum age in days
 * ({@code maxPeriodDay}) and the maximum "most recently solved" sequence number
 * ({@code maxCount}) that still falls into the bucket.
 */
@Getter
public enum ProblemSolvingAnalysisType {

    // Invariant to keep when editing the constants: a larger maxPeriodDay needs a larger maxCount.
    // Short-term: up to 200 solving records from the last 14 days, weight 0.5.
    SHORT_TERM(0.5, 14, 200),
    // Mid-term: up to 300 solving records from the last 30 days, weight 0.3.
    MID_TERM(0.3, 30, 300),
    // Long-term: up to 500 solving records from the last 90 days, weight 0.2.
    LONG_TERM(0.2, 90, 500);

    final double weight;
    final int maxPeriodDay;
    final int maxCount;

    ProblemSolvingAnalysisType(double weight, int maxPeriodDay, int maxCount) {
        this.weight = weight;
        this.maxPeriodDay = maxPeriodDay;
        this.maxCount = maxCount;
    }

    /**
     * Returns the bucket with the widest window; this is fixed to {@link #LONG_TERM}.
     */
    public static ProblemSolvingAnalysisType getLongestPeriodAnalysisType() {
        return LONG_TERM;
    }

    /**
     * Resolves the bucket for a solving record.
     *
     * <p>Candidates are tried in ascending {@code maxPeriodDay} order and the first one whose
     * limits cover both arguments wins. Both limits are inclusive.
     * NOTE(review): because the comparison is inclusive, {@code sequence == maxCount} still
     * matches; confirm whether callers pass a 0-based or a 1-based sequence.
     *
     * @param period   age of the record in days
     * @param sequence position of the record in most-recently-solved order
     * @return the first bucket whose limits cover both arguments
     * @throws AnalysisBusinessException with {@code CANNOT_ANALYSE_PROBLEM_SOLVING} when no bucket matches
     */
    public static ProblemSolvingAnalysisType fromPeriodAndCount(int period, int sequence) {
        // values() hands back a fresh array on every call, so sorting it in place is safe.
        ProblemSolvingAnalysisType[] candidates = values();
        Arrays.sort(candidates, Comparator.comparingInt(ProblemSolvingAnalysisType::getMaxPeriodDay));

        for (ProblemSolvingAnalysisType candidate : candidates) {
            boolean withinPeriod = period <= candidate.getMaxPeriodDay();
            boolean withinCount = sequence <= candidate.getMaxCount();
            if (withinPeriod && withinCount) {
                return candidate;
            }
        }
        throw new AnalysisBusinessException(AnalysisErrorCode.CANNOT_ANALYSE_PROBLEM_SOLVING);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public enum AnalysisErrorCode implements ErrorCode {
CANNOT_CALCULATE_VULNERABILITY("취약점 분석을 할 수 없습니다.", "ANA_001", ErrorCode.INTERNAL_SERVER_ERROR),
CANNOT_FIND_VECTOR("벡터를 찾을 수 없습니다.", "ANA_002", ErrorCode.INTERNAL_SERVER_ERROR),
NOT_ENOUGH_SOLVED_PROBLEMS("문제를 충분히 풀지 않아서 분석할 수 없습니다.", "ANA_003", ErrorCode.BAD_REQUEST),
CANNOT_ANALYSE_PROBLEM_SOLVING("분석할 수 없는 문제 풀이 기록입니다.", "ANA_004", ErrorCode.INTERNAL_SERVER_ERROR),
;

private final String message;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.jabiseo.analysis.service;

import com.jabiseo.analysis.domain.ProblemSolvingAnalysisType;
import com.jabiseo.analysis.dto.VulnerableSubjectDto;
import com.jabiseo.analysis.dto.VulnerableTagDto;
import com.jabiseo.analysis.exception.AnalysisBusinessException;
Expand All @@ -9,14 +10,17 @@
import com.jabiseo.learning.domain.ProblemSolvingRepository;
import com.jabiseo.member.domain.Member;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.Pageable;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.time.temporal.ChronoUnit.DAYS;
import static java.time.LocalDateTime.now;

@Service
@RequiredArgsConstructor
Expand Down Expand Up @@ -51,52 +55,61 @@ public List<Long> findVulnerableProblems(Member member, Certificate certificate)
return vulnerabilityProvider.findVulnerableProblems(vulnerableVector, certificate.getId(), DEFAULT_VULNERABLE_PROBLEM_COUNT);
}

// Builds the member's vulnerable vector for a certificate from their problem-solving records.
// NOTE(review): this span comes from a commit diff rendered without +/- markers, so it appears
// to interleave the previous and the revised version of the method; each comment below
// describes only the statement that follows it.
private List<Float> findVulnerableVector(Member member, Certificate certificate) {
// Lower bound of the lookup window: now minus YEARS_OF_ANALYSIS years.
LocalDateTime oneYearAgo = LocalDateTime.now().minusYears(YEARS_OF_ANALYSIS);
// TODO: need to test which query is more efficient
List<ProblemSolving> problemSolvings = problemSolvingRepository.findByMemberAndCertificateAndCreatedAtAfterWithLearning(member, certificate, oneYearAgo);
List<Float> findVulnerableVector(Member member, Certificate certificate) {

if (problemSolvings.isEmpty()) {
// The analysis type returned here supplies both limits of the lookup:
// its maxPeriodDay sets the earliest date and its maxCount sets the page size.
ProblemSolvingAnalysisType longestPeriodAnalysisType = ProblemSolvingAnalysisType.getLongestPeriodAnalysisType();
LocalDateTime fromDate = now().minusDays(longestPeriodAnalysisType.getMaxPeriodDay());
Pageable pageable = Pageable.ofSize(longestPeriodAnalysisType.getMaxCount());

// Load the solving records created after fromDate, limited to one page.
List<ProblemSolving> longestTermProblemSolvings = problemSolvingRepository.findWithLearningByCreatedAtAfterOrderByCreatedAtDesc(member, certificate, fromDate, pageable);

// Nothing to analyse: report NOT_ENOUGH_SOLVED_PROBLEMS to the caller.
if (longestTermProblemSolvings.isEmpty()) {
throw new AnalysisBusinessException(AnalysisErrorCode.NOT_ENOUGH_SOLVED_PROBLEMS);
}

// Collect each solved problem's id once.
List<Long> distinctProblemIds = problemSolvings.stream()
List<Long> distinctProblemIds = longestTermProblemSolvings.stream()
.map(problemSolving -> problemSolving.getProblem().getId())
.distinct()
.toList();

// Look up the vector of every distinct problem for this certificate.
Map<Long, List<Float>> problemIdToVector = vulnerabilityProvider.findVectorsOfProblems(distinctProblemIds, certificate.getId());
return calculateVulnerableVector(problemSolvings, problemIdToVector);
// Compute one weight per problem id, then combine vectors and weights into the result.
Map<Long, Double> problemIdToWeight = calculateWeightsOfProblems(longestTermProblemSolvings);
return calculateVulnerableVector(distinctProblemIds, problemIdToVector, problemIdToWeight);
}

// Changed to package-private for testing.
// NOTE(review): this looks like the head of the previous calculateVulnerableVector overload as
// rendered by the diff view; its closing lines are not contiguous with it.
List<Float> calculateVulnerableVector(List<ProblemSolving> problemSolvings, Map<Long, List<Float>> problemIdToVector) {
// Multiply each solved problem's vector by its weight, sum the results, and return the sum.
return problemSolvings.stream()
.map(problemSolving -> {
List<Float> problemVector = problemIdToVector.get(problemSolving.getProblem().getId());
double weight = calculateWeight(problemSolving.getLearning().getCreatedAt(), problemSolving.isCorrect());
return problemVector.stream()
// Builds a map from problem id to the summed weight of all of its solving records.
// The list index of each record is passed to calculateWeight as its sequence; when a problem id
// occurs more than once, the individual weights are added together.
Map<Long, Double> calculateWeightsOfProblems(List<ProblemSolving> problemSolvings) {
    Map<Long, Double> weightByProblemId = new java.util.HashMap<>();
    for (int sequence = 0; sequence < problemSolvings.size(); sequence++) {
        ProblemSolving solving = problemSolvings.get(sequence);
        Long problemId = solving.getProblem().getId();
        Double weight = calculateWeight(solving, sequence);
        // merge() inserts the first weight and sums every later one for the same id.
        weightByProblemId.merge(problemId, weight, Double::sum);
    }
    return weightByProblemId;
}

// Scales each problem's vector by that problem's weight and sums the scaled vectors element-wise.
// NOTE(review): this span comes from a diff rendered without +/- markers, so two reduce/orElseThrow
// tails appear back to back; they look like the previous and the revised ending of the pipeline.
List<Float> calculateVulnerableVector(List<Long> distinctProblemIds, Map<Long, List<Float>> problemIdToVector, Map<Long, Double> problemIdToWeight) {
return distinctProblemIds.stream()
// For each problem id, multiply every element of its vector by the weight.
.map(problemId -> {
// NOTE(review): Map.get returns null for an id missing from either map, which would fail
// on vector.stream() or on unboxing the weight; confirm both maps cover every id.
List<Float> vector = problemIdToVector.get(problemId);
double weight = problemIdToWeight.get(problemId);
return vector.stream()
.map(value -> (float) (value * weight))
.toList();
})
.reduce((vector1, vector2) ->
IntStream.range(0, vector1.size())
.mapToObj(i -> vector1.get(i) + vector2.get(i))
.toList()
)
.orElseThrow(() -> new AnalysisBusinessException(AnalysisErrorCode.CANNOT_CALCULATE_VULNERABILITY));
// Finally, add up the per-problem vectors element by element.
// The index range is taken from vector1, so this presumably expects all vectors to have the
// same length; verify against the vector provider.
.reduce((vector1, vector2) -> IntStream.range(0, vector1.size())
.mapToObj(i -> vector1.get(i) + vector2.get(i))
.collect(Collectors.toList()))
.orElseThrow(() -> new AnalysisBusinessException(AnalysisErrorCode.NOT_ENOUGH_SOLVED_PROBLEMS));
}

// Computes the signed weight of a solving record.
// NOTE(review): this span comes from a diff rendered without +/- markers and appears to contain
// both the previous (createdAt, isCorrect) and the revised (problemSolving, sequence) version.
private double calculateWeight(LocalDateTime createdAt, boolean isCorrect) {
long daysBetween = createdAt.until(LocalDateTime.now(), DAYS);
// Weight is inversely proportional to the elapsed time: a gap of N days yields 1/(N+1).
// A correctly answered problem has its weight multiplied by -1.
if (isCorrect) {
return -1.0 / (daysBetween + 1);
} else {
return 1.0 / (daysBetween + 1);
}
// Given a problemSolving and its position in most-recently-solved order, returns the weight.
// sequence is that position.
double calculateWeight(ProblemSolving problemSolving, int sequence) {
// Whole days between the learning's creation time and now.
int daysAgo = (int) ChronoUnit.DAYS.between(problemSolving.getLearning().getCreatedAt(), now());
ProblemSolvingAnalysisType analysisType = ProblemSolvingAnalysisType.fromPeriodAndCount(daysAgo, sequence);
// When the problem was answered correctly, the weight is multiplied by -1.
return problemSolving.isCorrect() ? -analysisType.getWeight() : analysisType.getWeight();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@

import com.jabiseo.certificate.domain.Certificate;
import com.jabiseo.member.domain.Member;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;

import java.time.LocalDateTime;
import java.util.List;

/**
 * Spring Data JPA repository for {@code ProblemSolving} entities.
 *
 * <p>NOTE(review): this listing comes from a diff rendered without +/- markers, so the two query
 * methods below are probably the previous and the revised declaration rather than a coexisting pair.
 */
public interface ProblemSolvingRepository extends JpaRepository<ProblemSolving, Long> {

/**
 * Finds the member's solving records, fetch-joined with their learning, whose learning belongs
 * to the given certificate and was created after {@code createdAt}.
 */
@Query("select ps from ProblemSolving ps join fetch ps.learning l where ps.member = :member and l.certificate = :learning_certificate and l.createdAt > :createdAt")
List<ProblemSolving> findByMemberAndCertificateAndCreatedAtAfterWithLearning(Member member, Certificate learning_certificate, LocalDateTime createdAt);
/**
 * Finds the member's solving records, fetch-joined with their learning, whose learning belongs
 * to the given certificate and was created after {@code fromDate}, ordered by the learning's
 * creation time, newest first.
 *
 * @param member      owner of the solving records
 * @param certificate certificate the learning must belong to
 * @param fromDate    exclusive lower bound on the learning's creation time
 * @param pageable    paging information passed alongside the query
 * @return the matching solving records in descending creation-time order
 */
@Query("SELECT ps " +
"FROM ProblemSolving ps " +
"JOIN FETCH ps.learning l " +
"WHERE ps.member = :member AND l.certificate = :certificate AND l.createdAt > :fromDate " +
"ORDER BY l.createdAt DESC")
List<ProblemSolving> findWithLearningByCreatedAtAfterOrderByCreatedAtDesc(
@Param("member") Member member,
@Param("certificate") Certificate certificate,
@Param("fromDate") LocalDateTime fromDate,
Pageable pageable
);

}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public Page<ProblemWithBookmarkSummaryQueryDto> findBookmarkedSummaryByExamIdAnd
.join(problem.exam, exam)
.join(problem.subject, subject)
.join(bookmark).on(bookmark.problem.id.eq(problem.id))
.where(examIdEq(examId), subjectIdsIn(subjectIds));
.where(memberIdEq(memberId), examIdEq(examId), subjectIdsIn(subjectIds));

return PageableExecutionUtils.getPage(content, pageable, countQuery::fetchOne);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.jabiseo.analysis.domain;

import com.jabiseo.analysis.exception.AnalysisBusinessException;
import com.jabiseo.analysis.exception.AnalysisErrorCode;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Boundary-value tests for {@link ProblemSolvingAnalysisType#fromPeriodAndCount(int, int)}.
 *
 * <p>Each bucket is checked on its inclusive limit and one step past it, so that both the
 * SHORT_TERM to MID_TERM and the MID_TERM to LONG_TERM transitions are pinned.
 */
@DisplayName("ProblemSolvingAnalysisType 테스트")
@ExtendWith(MockitoExtension.class)
class ProblemSolvingAnalysisTypeTest {

    // Columns: period in days, sequence (count), expected analysis type.
    @DisplayName("기간과 문제 수에 따른 분석 타입을 반환한다.")
    @ParameterizedTest
    @CsvSource({
            "0, 0, SHORT_TERM",
            "13, 199, SHORT_TERM",
            "14, 200, SHORT_TERM",
            // One step past a SHORT_TERM limit falls through to MID_TERM.
            "15, 200, MID_TERM",
            "14, 201, MID_TERM",
            "30, 300, MID_TERM",
            // One step past a MID_TERM limit falls through to LONG_TERM.
            "31, 300, LONG_TERM",
            "30, 301, LONG_TERM",
            "90, 500, LONG_TERM"
    })
    void fromPeriodAndCount(int period, int count, ProblemSolvingAnalysisType expected) {
        // when
        ProblemSolvingAnalysisType actual = ProblemSolvingAnalysisType.fromPeriodAndCount(period, count);

        // then
        assertEquals(expected, actual);
    }

    // Columns: period in days, sequence (count); every row exceeds at least one LONG_TERM limit.
    @DisplayName("기간과 문제 수가 올바르지 않을 경우 예외처리한다.")
    @ParameterizedTest
    @CsvSource({
            "91, 100",
            "89, 501",
            "91, 501"
    })
    void fromPeriodAndCountWithDefault(int period, int count) {
        // when & then
        assertThatThrownBy(() -> ProblemSolvingAnalysisType.fromPeriodAndCount(period, count))
                .isInstanceOf(AnalysisBusinessException.class)
                .hasFieldOrPropertyWithValue("errorCode", AnalysisErrorCode.CANNOT_ANALYSE_PROBLEM_SOLVING);
    }
}
Loading

0 comments on commit b03fd49

Please sign in to comment.