{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n", "import torch\n", "DEVICE = torch.device(\"cuda:0\")\n", "\n", "model_name_or_path = \"sberbank-ai/rugpt3small_based_on_gpt2\"\n", "tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)\n", "model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "with open('anekdoty.txt', 'r', encoding='utf-8') as file:\n", " text = file.read()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/polyakovk/venv_linux/lib/python3.11/site-packages/transformers/data/datasets/language_modeling.py:53: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py\n", " warnings.warn(\n" ] } ], "source": [ "from transformers import TextDataset, DataCollatorForLanguageModeling\n", "\n", "# Сохраним обучающие данные в .txt файл \n", "train_path = 'train_dataset.txt'\n", "with open(train_path, \"w\") as f:\n", " f.write(text)\n", "\n", "# Создание датасета\n", "train_dataset = TextDataset(tokenizer=tokenizer,file_path=train_path,block_size=32)\n", " \n", "# Создание даталодера (нарезает текст на оптимальные по длине куски)\n", "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer, TrainingArguments\n", "\n", "training_args = TrainingArguments(\n", " output_dir=\"./finetuned\",\n", " overwrite_output_dir=True,\n", " num_train_epochs=30,\n", " per_device_train_batch_size=32,\n", " per_device_eval_batch_size=16,\n", " warmup_steps=10,\n", " gradient_accumulation_steps=32,\n", " )\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=data_collator,\n", " train_dataset=train_dataset,\n", " optimizers = (torch.optim.AdamW(model.parameters(),lr=0.001),None)\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [240/240 1:14:57, Epoch 27/30]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss

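  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reported `training_loss` is the mean per-token cross-entropy, so its exponential is the training perplexity — roughly, the effective number of tokens the model hesitates between at each step. A quick back-of-the-envelope check (not part of the original run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Training perplexity from the final loss above: exp(0.9344) ≈ 2.55\n",
    "print(math.exp(0.9343911488850911))"
   ]
  },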
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=240, training_loss=0.9343911488850911, metrics={'train_runtime': 4515.8084, 'train_samples_per_second': 58.428, 'train_steps_per_second': 0.053, 'total_flos': 4011240960000000.0, 'train_loss': 0.9343911488850911, 'epoch': 27.927272727272726})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "model_path = \"finetuned\"\n", "tokenizer = GPT2Tokenizer.from_pretrained(model_path)\n", "model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "def generate_jokes(prompt, temperature, top_p, max_length, num_return_sequences):\n", " input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)\n", " \n", " # Генерируем несколько шуток\n", " outputs = model.generate(\n", " input_ids=input_ids,\n", " do_sample=True,\n", " # num_beams=5,\n", " temperature=temperature,\n", " top_p=top_p,\n", " max_length=max_length,\n", " num_return_sequences=num_return_sequences\n", " )\n", " \n", " # Обработка всех сгенерированных шуток\n", " jokes = []\n", " for output in outputs:\n", " generated_text = tokenizer.decode(output, skip_special_tokens=True)\n", " # Обрезаем текст после первой точки\n", " if '…' in generated_text:\n", " generated_text = generated_text.split('…')[0] + '.'\n", " elif '.' in generated_text:\n", " generated_text = generated_text.split('.')[0] + '.'\n", " elif '!' in generated_text:\n", " generated_text = generated_text.split('!')[0] + '.'\n", " jokes.append(generated_text)\n", " \n", " return jokes" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['Шла Саша по шоссе, громко разговаривая с шофером.', 'Шла Саша по шоссе, громко матерясь и упирая руку в ширинку.', 'Шла Саша по шоссе, несла пургу и, как раз, дождь.', 'Шла Саша по шоссе, но не за трактором.']\n" ] } ], "source": [ "text = \"Шла Саша по шоссе\"\n", "print(generate_jokes(text, 1, 0.9, 30, 4))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "однажды я проваливал экзамен по истории.\n", "— Вино с возрастом становится лучше. 
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "однажды я проваливал экзамен по истории.\n",
      "— Вино с возрастом становится лучше. Я становлюсь лучше с вином…\n",
      "— Сними\n"
     ]
    }
   ],
   "source": [
    "text = \"однажды я пришел из школы\"  # prompt: \"one day I came home from school\"\n",
    "input_ids = tokenizer.encode(text, return_tensors=\"pt\").to(DEVICE)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    out = model.generate(\n",
    "        input_ids,\n",
    "        do_sample=True,\n",
    "        num_beams=2,\n",
    "        temperature=1.5,\n",
    "        top_p=0.9,\n",
    "        max_length=30,\n",
    "    )\n",
    "\n",
    "generated_text = list(map(tokenizer.decode, out))[0]\n",
    "print()\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model.save_pretrained('./finetuned')\n",
    "# tokenizer.save_pretrained('./finetuned')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import requests\n",
    "# from bs4 import BeautifulSoup\n",
    "# import re\n",
    "\n",
    "# # Fetch the jokes from a single listing page\n",
    "# def get_jokes_from_page(url):\n",
    "#     response = requests.get(url, headers=headers)\n",
    "#     response.raise_for_status()  # raise on HTTP errors\n",
    "\n",
    "#     soup = BeautifulSoup(response.text, 'html.parser')\n",
    "\n",
    "#     # Find all jokes on the page\n",
    "#     jokes = soup.find_all('div', class_='anekdot-text')  # replace with the correct selector if needed\n",
    "\n",
    "#     page_jokes = []\n",
    "#     for joke in jokes:\n",
    "#         # Extract the joke text\n",
    "#         joke_text = joke.get_text(strip=True)\n",
    "\n",
    "#         # Strip trailing digits and symbols from the end of the text\n",
    "#         joke_text_cleaned = re.sub(r'\\d+[\\#\\d]*$', '', joke_text).strip()\n",
    "\n",
    "#         # Store the cleaned text\n",
    "#         page_jokes.append(joke_text_cleaned)\n",
    "\n",
    "#     return page_jokes\n",
    "\n",
    "# # URL template for the paginated listing\n",
    "# base_url = \"https://anekdotovstreet.com/korotkie-anekdoty/{}/\"\n",
    "\n",
    "# # Headers to mimic a regular browser\n",
    "# headers = {\n",
    "#     'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'\n",
    "# }\n",
    "\n",
    "# # Write the scraped jokes to a file\n",
    "# with open('anekdoty.txt', 'w', encoding='utf-8') as file:\n",
    "#     for page_number in range(2, 400):\n",
    "#         # Build the URL for the current page\n",
    "#         url = base_url.format(page_number)\n",
    "#         print(f\"Scraping jokes from page {page_number}...\")\n",
    "\n",
    "#         # Fetch the jokes from the current page\n",
    "#         jokes = get_jokes_from_page(url)\n",
    "\n",
    "#         # No jokes means the pages have run out (optional check)\n",
    "#         if not jokes:\n",
    "#             print(f\"No jokes found on page {page_number}.\")\n",
    "#             continue\n",
    "\n",
    "#         # Append the jokes to the file\n",
    "#         for joke in jokes:\n",
    "#             file.write(joke + '\\n')\n",
    "\n",
    "# print(\"Jokes successfully saved to 'anekdoty.txt'.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv_linux",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}