Skip to content

Commit

Permalink
added minor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreFCruz committed Jun 25, 2024
1 parent a89d4de commit d8db5fd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Pytest fixtures.
"""
from __future__ import annotations

import pytest
import numpy as np


TEST_CAUSAL_LMS = [
"hf-internal-testing/tiny-random-gpt2",
# "hf-internal-testing/tiny-random-MistralForCausalLM",
]


@pytest.fixture(params=[42])
def random_seed(request) -> int:
return request.param


@pytest.fixture
def rng(random_seed: int) -> np.random.Generator:
return np.random.default_rng(random_seed)


@pytest.fixture(params=TEST_CAUSAL_LMS)
def causal_lm_name_or_path(request) -> str:
"""Name or path of the CausalLM used for testing."""
return request.param
15 changes: 15 additions & 0 deletions tests/test_load_huggingface_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test functions in folktexts.llm_utils
"""

from transformers import PreTrainedModel, PreTrainedTokenizerBase
from folktexts.llm_utils import load_model_tokenizer


def test_load_model_tokenizer(causal_lm_name_or_path: str):
model, tokenizer = load_model_tokenizer(causal_lm_name_or_path)

assert isinstance(model, PreTrainedModel), \
f"Expected model type `PreTrainedModel`, got {type(model)}."

assert isinstance(tokenizer, PreTrainedTokenizerBase), \
f"Expected tokenizer type `PreTrainedTokenizer`, got {type(tokenizer)}."

0 comments on commit d8db5fd

Please sign in to comment.