This repository has been archived by the owner on Jun 14, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_song_indic_bart.py
executable file
·130 lines (115 loc) · 6.7 KB
/
generate_song_indic_bart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/python3
import sys
import os
import random
os.environ['PYTHONPATH'] = '/home/ubuntu/.local/lib/python3.10/site-packages'
print("sys.executable:", sys.executable)
print("sys.path:", sys.path)
print("Environment Variables:", os.environ)
print("PYTHONPATH:", os.environ.get('PYTHONPATH', 'Not Set'))
print("sys.path at runtime:", sys.path)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
import torch
# Load the tokenizer and model from the local files
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBART", do_lower_case=False, use_fast=False, keep_accents=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/IndicBART")
# Load the model weights from the pytorch_model.bin file
model_weights_path = "./model_output/pytorch_model.bin"
model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')), strict=False)
# Define a list of prompts for generating songs in Telugu
prompts = [
"ఈ పాట గురించి ప్రేమ, బాధ, ఆనందం, మరియు జీవితం గురించి ఒక పూర్తి పాట రాయండి. పాట ప్రారంభం:",
"ఈ పాట గురించి స్నేహం, ఆశ, మరియు కలలు గురించి ఒక పూర్తి పాట రాయండి. పాట ప్రారంభం:",
"ఈ పాట గురించి కుటుంబం, ఆనందం, మరియు ఆశయం గురించి ఒక పూర్తి పాట రాయండి. పాట ప్రారంభం:",
"ఈ పాట గురించి ప్రకృతి, సౌందర్యం, మరియు ప్రశాంతత గురించి ఒక పూర్తి పాట రాయండి. పాట ప్రారంభం:",
"ఈ పాట గురించి విజయాలు, సవాళ్లు, మరియు ప్రేరణ గురించి ఒక పూర్తి పాట రాయండి. పాట ప్రారంభం:"
]
# Randomly select a prompt from the list
prompt = random.choice(prompts)
# Use the prompt directly in Telugu script
prompt_telugu = f"{prompt} </s> <2te>"
print("Prompt in Telugu:", prompt_telugu)
# Tokenize the input prompt in Devanagari script
inputs = tokenizer(prompt_telugu, return_tensors="pt")
print("Tokenized input IDs:", inputs.input_ids)
print("Token ID range:", inputs.input_ids.min().item(), "-", inputs.input_ids.max().item())
print("Model embedding layer configuration:", model.get_input_embeddings())
# Generate text using the model
bos_token_id = tokenizer._convert_token_to_id_with_added_voc("<s>")
eos_token_id = tokenizer._convert_token_to_id_with_added_voc("</s>")
decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc("<2te>") # Use Telugu token for generation
outputs = model.generate(
inputs.input_ids,
max_length=1500, # Increase max_length to allow for longer sequences
num_beams=5,
early_stopping=True,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
forced_bos_token_id=decoder_start_token_id,
no_repeat_ngram_size=3,
num_return_sequences=5,
repetition_penalty=1.5, # Add repetition penalty to discourage repeated phrases
length_penalty=1.0, # Add length penalty to encourage longer sequences
temperature=0.9, # Adjust temperature to balance diversity and coherence
top_k=40, # Adjust top-k sampling to limit the number of highest probability tokens to keep for generation
top_p=0.9, # Adjust top-p (nucleus) sampling to keep the smallest set of tokens with cumulative probability >= top_p
do_sample=True # Enable sampling to allow temperature to take effect
)
print("Generated output IDs:", outputs)
# Decode the generated text
generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
print("Generated texts:", generated_texts)
# Filter out non-Telugu characters and placeholders from the generated text
# Filter out non-Telugu characters and placeholders from the generated text
filtered_texts_telugu = []
for text in generated_texts:
filtered_text = ''.join([char for char in text if 0x0C00 <= ord(char) <= 0x0C7F or char in [' ', '.', ',', '!', '?', ':', ';', '-', '(', ')', '[', ']', '{', '}', '"', "'", '\n', '\t']])
# Remove placeholders and incomplete words
filtered_text = filtered_text.replace('...', '').replace('?', '').replace('!', '').replace(':', '').replace(';', '').replace('-', '').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace('{', '').replace('}', '').replace('"', '').replace("'", '')
# Ensure the text is not empty and has a minimum length
if len(filtered_text.strip()) > 10:
filtered_texts_telugu.append(filtered_text.strip())
print("Filtered texts in Telugu:", filtered_texts_telugu)
# Implement a content filter to check for inappropriate content
prohibited_words = ["గైంగరేప్", "అశ్లీల", "అమానవీయ"]
final_texts_telugu = []
for text in filtered_texts_telugu:
if not any(word in text for word in prohibited_words):
final_texts_telugu.append(text)
else:
# Regenerate text if inappropriate content is found
new_outputs = model.generate(
inputs.input_ids,
max_length=1500,
num_beams=5,
early_stopping=True,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
forced_bos_token_id=decoder_start_token_id,
no_repeat_ngram_size=3,
num_return_sequences=1,
repetition_penalty=1.5,
length_penalty=1.0,
temperature=1.0,
top_k=50,
top_p=0.9,
do_sample=True
)
new_generated_text = tokenizer.decode(new_outputs[0], skip_special_tokens=True)
new_filtered_text = ''.join([char for char in new_generated_text if 0x0C00 <= ord(char) <= 0x0C7F or char in [' ', '.', ',', '!', '?', ':', ';', '-', '(', ')', '[', ']', '{', '}', '"', "'", '\n', '\t']])
if len(new_filtered_text) > 10:
final_texts_telugu.append(new_filtered_text)
# Print the generated songs
print("Generated Songs in Telugu:")
for idx, song in enumerate(final_texts_telugu):
print(f"Song {idx + 1}:")
print(song)
print()
# Save the generated songs to a file
with open("generated_song_demo.txt", "w", encoding="utf-8") as f:
for idx, song in enumerate(final_texts_telugu):
f.write(f"Song {idx + 1}:\n")
f.write(song + "\n\n")