# music-ai / model.py
# Multi-model AI music generator: melody, harmony, and rhythm models
# chained into a MIDI file, served through a Gradio interface.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from music21 import converter, instrument, note, chord, stream
import gradio as gr
# Load Models
# NOTE(review): the "your_*_model" strings are placeholders — from_pretrained()
# will fail until they are replaced with real Hugging Face model IDs or local
# paths. Loading happens at import time (downloads weights on first run).
melody_model = AutoModelForCausalLM.from_pretrained("your_melody_model") # Replace with actual model name
harmony_model = AutoModelForCausalLM.from_pretrained("your_harmony_model") # Replace with actual model name
rhythm_model = AutoModelForCausalLM.from_pretrained("your_rhythm_model") # Replace with actual model name
# One shared tokenizer for all three models; assumes they are GPT-2 compatible.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Define functions for each step
def generate_melody(prompt, length=50):
    """Generate a melody token sequence from a text prompt.

    Args:
        prompt: Free-text description used to seed the melody model.
        length: Maximum total sequence length in tokens (prompt included).

    Returns:
        The decoded melody sequence as a string.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Pass the attention mask and set pad_token_id explicitly: GPT-2 defines
    # no pad token, and generate() warns (and may mis-handle padding) when
    # neither is supplied.
    melody_output = melody_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(melody_output[0], skip_special_tokens=True)
def generate_harmony(melody_sequence, length=100):
    """Generate harmonic support for a melody sequence.

    Args:
        melody_sequence: Decoded melody tokens (output of the melody stage).
        length: Maximum total sequence length in tokens (input included).

    Returns:
        The decoded harmony sequence as a string.
    """
    # Tokenize both pieces the same way (the original mixed tokenizer.encode
    # and tokenizer(...)), then concatenate along the sequence dimension so
    # the "add harmony" cue follows the melody.
    melody_ids = tokenizer(melody_sequence, return_tensors="pt")["input_ids"]
    cue_ids = tokenizer("add harmony", return_tensors="pt")["input_ids"]
    harmony_input = torch.cat([melody_ids, cue_ids], dim=1)
    # pad_token_id set explicitly: GPT-2-style models define no pad token.
    harmony_output = harmony_model.generate(
        harmony_input,
        max_length=length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(harmony_output[0], skip_special_tokens=True)
def generate_rhythm(harmony_sequence, length=50):
    """Add rhythmic structure to a harmony sequence.

    Args:
        harmony_sequence: Decoded harmony tokens (output of the harmony stage).
        length: Maximum total sequence length in tokens (input included).

    Returns:
        The decoded rhythm sequence as a string.
    """
    # Tokenize both pieces the same way (the original mixed tokenizer.encode
    # and tokenizer(...)), then append the "add rhythm" cue to the harmony.
    harmony_ids = tokenizer(harmony_sequence, return_tensors="pt")["input_ids"]
    cue_ids = tokenizer("add rhythm", return_tensors="pt")["input_ids"]
    rhythm_input = torch.cat([harmony_ids, cue_ids], dim=1)
    # pad_token_id set explicitly: GPT-2-style models define no pad token.
    rhythm_output = rhythm_model.generate(
        rhythm_input,
        max_length=length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(rhythm_output[0], skip_special_tokens=True)
def create_midi(melody, harmony, rhythm, output_path="generated_music.mid"):
    """Convert melody, harmony, and rhythm token strings to a MIDI file.

    Tokens are whitespace-separated. A token that is all digits is treated as
    a MIDI note number; the literal token "rest" becomes a rest; every other
    token (ordinary words from the language models) is silently skipped.

    Args:
        melody: Melody token string.
        harmony: Harmony token string.
        rhythm: Rhythm token string.
        output_path: Where to write the MIDI file (kept as the original
            hard-coded default for backward compatibility).

    Returns:
        The path of the written MIDI file.
    """
    composition = stream.Stream()
    for part in (melody, harmony, rhythm):
        for token in part.split():
            if token.isdigit():
                midi_value = int(token)
                # Guard the MIDI range: model output is free text, and an
                # out-of-range number (e.g. "300") would make note.Note()
                # raise instead of producing a playable note.
                if 0 <= midi_value <= 127:
                    midi_note = note.Note(midi_value)
                    midi_note.quarterLength = 0.5  # fixed eighth-note grid
                    composition.append(midi_note)
            elif token == "rest":
                rest_note = note.Rest()
                rest_note.quarterLength = 0.5
                composition.append(rest_note)
    composition.write('midi', fp=output_path)
    return output_path
# Full generation function
def generate_music(prompt, length=50):
    """Run the full pipeline: melody -> harmony -> rhythm -> MIDI file.

    Args:
        prompt: Text description seeding the melody model.
        length: Maximum token length passed to every generation stage.

    Returns:
        Path to the generated MIDI file.
    """
    melody_seq = generate_melody(prompt, length)
    harmony_seq = generate_harmony(melody_seq, length)
    rhythm_seq = generate_rhythm(harmony_seq, length)
    return create_midi(melody_seq, harmony_seq, rhythm_seq)
# Set up the Gradio interface. A configured gr.Slider replaces the bare
# "slider" shorthand: the default shorthand slider starts at 0, which would
# request a zero-length generation; explicit bounds and labels also make the
# UI self-describing.
iface = gr.Interface(
    fn=generate_music,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Describe the music you want"),
        gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Length (tokens)"),
    ],
    outputs=gr.File(label="Generated MIDI"),
    title="Multi-Model AI Music Generator",
    description="Generate music using a multi-model AI system that combines melody, harmony, and rhythm layers.",
)
iface.launch()