{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LitTrainer\n",
    "\n",
    "`LitTrainer` wraps an existing `model`, `loss_fn` and optimizer in a `pytorch_lightning.LightningModule`. Its training, validation and test steps each compute the loss on one batch and log it as `train_loss`, `val_loss` or `test_loss` respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim\n",
    "import pytorch_lightning as pl"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class LitTrainer(pl.LightningModule):\n",
    "    \"\"\"Lightning wrapper around an existing model, loss function and optimizer.\n",
    "\n",
    "    Args:\n",
    "        model: callable (typically a ``torch.nn.Module``) mapping a batch of inputs\n",
    "            to predictions.\n",
    "        loss_fn: callable ``loss_fn(y_pred, y)`` returning the loss tensor.\n",
    "        optim: optimizer returned unchanged by ``configure_optimizers``; presumably\n",
    "            already constructed over ``model.parameters()`` -- verify at the call site.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, loss_fn, optim):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.loss_fn = loss_fn\n",
    "        self.optim = optim\n",
    "\n",
    "    def _shared_step(self, batch):\n",
    "        \"\"\"Return ``loss_fn(model(x), y)`` for one ``(x, y)`` batch.\n",
    "\n",
    "        Shared by the training, validation and test loops so all three compute the\n",
    "        loss identically.\n",
    "        \"\"\"\n",
    "        x, y = batch\n",
    "        # Cast inputs to float32, presumably to match the model's parameter dtype\n",
    "        # (e.g. when the DataLoader yields float64 tensors built from numpy).\n",
    "        x = x.to(torch.float32)\n",
    "\n",
    "        # NOTE(review): predictions are flattened to shape (1, -1) regardless of the\n",
    "        # batch size. That lines up with the target only for batch_size=1 or when\n",
    "        # loss_fn broadcasts -- confirm against the DataLoader / loss_fn in use.\n",
    "        y_pred = self.model(x).reshape(1, -1)\n",
    "        return self.loss_fn(y_pred, y)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the training loss, log it as ``train_loss`` and return it.\"\"\"\n",
    "        train_loss = self._shared_step(batch)\n",
    "        self.log(\"train_loss\", train_loss)\n",
    "        return train_loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the validation loss and log it as ``val_loss``.\"\"\"\n",
    "        validate_loss = self._shared_step(batch)\n",
    "        self.log(\"val_loss\", validate_loss)\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the test loss and log it as ``test_loss``.\"\"\"\n",
    "        test_loss = self._shared_step(batch)\n",
    "        self.log(\"test_loss\", test_loss)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return the optimizer passed to ``__init__`` unchanged.\"\"\"\n",
    "        return self.optim\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}