summaryrefslogtreecommitdiffstats
path: root/test/test_post_hooks.py
blob: 3778d1794258679e55a386913be919db343ec3c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3

# Allow direct execution
import os
import sys
import unittest

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


from test.helper import get_params, is_download_test, try_rm
import yt_dlp.YoutubeDL  # isort: split
from yt_dlp.utils import DownloadError


class YoutubeDL(yt_dlp.YoutubeDL):
    """Test-local ``YoutubeDL`` whose ``to_stderr`` is replaced by ``to_screen``."""

    def __init__(self, *init_args, **init_kwargs):
        super().__init__(*init_args, **init_kwargs)
        # Rebind the per-instance stderr writer to the screen writer, so every
        # message that would go to stderr is emitted through to_screen instead.
        self.to_stderr = self.to_screen


# Video id handed to YoutubeDL.download(), and the basename (without
# extension) that the post hooks are expected to report for it.
TEST_ID, EXPECTED_NAME = 'gr51aVj-mLg', 'gr51aVj-mLg'


@is_download_test
class TestPostHooks(unittest.TestCase):
    """Exercise the ``post_hooks`` option of YoutubeDL.

    The hooks record the filename they are called with, so the tests can
    assert on it and ``tearDown`` can delete whatever was downloaded.
    """

    def setUp(self):
        # Extension-less basenames captured by hook_one / hook_two.
        self.stored_name_1 = None
        self.stored_name_2 = None
        self.params = get_params({
            'skip_download': False,
            'writeinfojson': False,
            'quiet': True,
            'verbose': False,
            'cachedir': False,
        })
        # Distinct filenames seen by any hook; removed again in tearDown.
        self.files = []

    def test_post_hooks(self):
        """Both registered hooks receive the downloaded file's name."""
        self.params['post_hooks'] = [self.hook_one, self.hook_two]
        # Context-manager form closes the YoutubeDL instance (releasing its
        # resources) even if download() raises.
        with YoutubeDL(self.params) as ydl:
            ydl.download([TEST_ID])
        self.assertEqual(self.stored_name_1, EXPECTED_NAME, 'Not the expected name from hook 1')
        self.assertEqual(self.stored_name_2, EXPECTED_NAME, 'Not the expected name from hook 2')

    def test_post_hook_exception(self):
        """An exception raised inside a post hook surfaces as DownloadError."""
        self.params['post_hooks'] = [self.hook_three]
        with YoutubeDL(self.params) as ydl, self.assertRaises(DownloadError):
            ydl.download([TEST_ID])

    def _track(self, filename):
        """Register *filename* for cleanup and return its basename without extension."""
        # Several hooks may report the same file; keep each path only once.
        if filename not in self.files:
            self.files.append(filename)
        return os.path.splitext(os.path.basename(filename))[0]

    def hook_one(self, filename):
        """Post hook that stores the reported basename in ``stored_name_1``."""
        self.stored_name_1 = self._track(filename)

    def hook_two(self, filename):
        """Post hook that stores the reported basename in ``stored_name_2``."""
        self.stored_name_2 = self._track(filename)

    def hook_three(self, filename):
        """Post hook that always fails, after registering *filename* for cleanup."""
        self._track(filename)
        raise Exception(f"Test exception for '{filename}'")

    def tearDown(self):
        # Remove everything the hooks reported. NOTE(review): assumes try_rm
        # tolerates paths that no longer exist — confirm in test.helper.
        for f in self.files:
            try_rm(f)


if __name__ == '__main__':
    # Direct execution: discover and run the TestCase classes in this module.
    unittest.main(module='__main__')