Skip to content

Commit

Permalink
use a tinymodel to test generation config which aviod timeout (#34482)
Browse files Browse the repository at this point in the history
* use a tinymodel to test generation config which aviod timeout

* remove tailing whitespace
  • Loading branch information
techkang authored Oct 29, 2024
1 parent 63ca6d9 commit 655bec2
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,15 +1544,16 @@ def test_pretrained_low_mem_new_config(self):
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)

def test_generation_config_is_loaded_with_model(self):
# Note: `TinyLlama/TinyLlama-1.1B-Chat-v1.0` has a `generation_config.json` containing `max_length: 2048`
# Note: `hf-internal-testing/tiny-random-MistralForCausalLM` has a `generation_config.json`
# containing `bos_token_id: 1`

# 1. Load without further parameters
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
self.assertEqual(model.generation_config.max_length, 2048)
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL)
self.assertEqual(model.generation_config.bos_token_id, 1)

# 2. Load with `device_map`
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto")
self.assertEqual(model.generation_config.max_length, 2048)
model = AutoModelForCausalLM.from_pretrained(TINY_MISTRAL, device_map="auto")
self.assertEqual(model.generation_config.bos_token_id, 1)

@require_safetensors
def test_safetensors_torch_from_torch(self):
Expand Down

0 comments on commit 655bec2

Please sign in to comment.
  NODES
chat 3
COMMUNITY 1
INTERN 1
Note 3
Project 3
USERS 1