Hello World (Classifying Names with a Character-Level RNN)¶
A getting-started example: a character-level RNN that classifies names based on the characters they contain.
Setup¶
Check and prepare the environment.
In [19]:
%pip install --upgrade pip
%pip install torch
%pip install unidecode
%pip install tensorboard
%pip install matplotlib
Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
Requirement already satisfied: pip (25.2), torch (2.8.0), unidecode (1.4.0), tensorboard (2.20.0)
Successfully installed contourpy-1.3.2 cycler-0.12.1 fonttools-4.59.0 kiwisolver-1.4.9 matplotlib-3.10.5 pyparsing-3.2.3
Note: you may need to restart the kernel to use updated packages. (pip output trimmed)
In [4]:
import torch

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
torch.set_default_device(device)
print(f"Using device = {torch.get_default_device()}")
Using device = cpu
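This run lands on cpu. On an Apple Silicon Mac you could also try PyTorch's MPS backend; a minimal sketch, assuming a recent PyTorch build with Metal support (not used in the runs below):

# Assumption: Apple Silicon + a PyTorch build with MPS support.
if not torch.cuda.is_available() and torch.backends.mps.is_available():
    device = torch.device('mps')
    torch.set_default_device(device)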
Prepare Dataset¶
In [5]:
from unidecode import unidecode
import torch.nn.functional as F

print(f"converting 'Ślusàrski' to {unidecode('Ślusàrski')}")

import string
allowed_chars = string.printable
print(f"Allowed characters: {allowed_chars}")

char_indexes = {c: i for i, c in enumerate(allowed_chars)}
print(f"Character indexes: {char_indexes}")

def one_hot_embedding(text):
    # map each (ASCII-transliterated) character to its index, then one-hot
    # encode to shape (seq_len, 1, n_chars); the extra dim is the batch axis
    indexes = torch.tensor([char_indexes[c] for c in unidecode(text)])
    return F.one_hot(indexes, num_classes=len(allowed_chars)).float().unsqueeze(1)

print(f"One-hot embedding for 'Hello': {one_hot_embedding('Hello')}")
converting 'Ślusàrski' to Slusarski
Allowed characters: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ (plus whitespace characters)
Character indexes: {'0': 0, '1': 1, '2': 2, ..., ' ': 94, '\t': 95, '\n': 96, '\r': 97, '\x0b': 98, '\x0c': 99}
One-hot embedding for 'Hello': tensor of shape [5, 1, 100], one row per character with a single 1. at that character's index (output trimmed)
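A quick sanity check that the encoding is invertible: taking the argmax along the class dimension should recover each character's index. A small sketch using only the names defined above:

# Decode the one-hot tensor back to text; argmax along dim=2 (the class
# dimension) returns each character's index into allowed_chars.
emb = one_hot_embedding('Hello')   # shape (5, 1, 100)
decoded = ''.join(allowed_chars[i] for i in emb.argmax(dim=2).squeeze(1).tolist())
assert decoded == 'Hello'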
In [7]:
from torch.utils.data import Dataset
import os
import torch

class NameDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.data = []
        self.data_tensors = []
        self.labels = []
        label_set = set()
        # each file is named <language>.txt and contains one name per line
        for filename in os.listdir(data_dir):
            if filename.endswith('.txt'):
                label = filename.split('.')[0]
                label_set.add(label)
                with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
                    lines = f.readlines()
                    for line in lines:
                        line = line.strip()
                        if line:
                            self.data.append(line)
                            self.data_tensors.append(one_hot_embedding(line))
                            self.labels.append(label)
        self.uniq_labels = list(label_set)
        self.label_tensors = [torch.tensor([self.uniq_labels.index(label)], dtype=torch.long) for label in self.labels]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_tensor = self.data_tensors[idx]
        label = self.labels[idx]
        label_tensor = self.label_tensors[idx]
        return label_tensor, data_tensor, label, data
In [8]:
name_dataset = NameDataset('/Users/shelton/dev/github/hacking/data/data/names')
print(f"Dataset size: {len(name_dataset)}")
print(f"First item in dataset: {name_dataset[0]}")
Dataset size: 20074
First item in dataset: (tensor([11]), tensor of shape [3, 1, 100] (one-hot encoding of 'Abl', trimmed), 'Czech', 'Abl')
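The 20,074 names are spread unevenly across the languages, which is worth keeping in mind when reading the loss curve later. A quick look at the class balance (a small sketch, not part of the original run):

from collections import Counter

# Count names per language label; the largest classes dominate the dataset.
print(Counter(name_dataset.labels).most_common(5))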
In [9]:
import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CharRNN, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        rnn_out, hidden = self.rnn(x)
        # classify from the final hidden state of the single RNN layer
        output = self.h2o(hidden[0])
        output = self.softmax(output)
        return output
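A note on shapes: with batch_first left at its default (False), nn.RNN expects input of shape (seq_len, batch, input_size) and returns hidden as (num_layers, batch, hidden_size), so hidden[0] is the final hidden state of the single layer. A minimal shape check, using sizes matching this notebook:

# Hypothetical throwaway model just to trace tensor shapes.
demo = CharRNN(input_size=100, hidden_size=128, output_size=18)
x = torch.zeros(6, 1, 100)   # a 6-character name, batch size 1
print(demo(x).shape)         # torch.Size([1, 18]) -- one log-probability per class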
In [10]:
n_hidden = 128
rnn = CharRNN(len(allowed_chars), n_hidden, len(name_dataset.uniq_labels))
print(rnn)
CharRNN(
  (rnn): RNN(100, 128)
  (h2o): Linear(in_features=128, out_features=18, bias=True)
  (softmax): LogSoftmax(dim=1)
)
In [11]:
def get_label_from_output(output, labels):
    top_n, top_i = output.topk(1)
    label_i = top_i[0].item()
    return labels[label_i], label_i

input = one_hot_embedding('Albert')
print(input.shape)
output = rnn(input)  # equivalent to ``output = rnn.forward(input)``
print(output)
print(get_label_from_output(output, name_dataset.uniq_labels))
torch.Size([6, 1, 100])
tensor([[-2.7712, -2.8355, -2.9082, -2.9814, -2.9606, -2.7589, -2.9449, -2.9469,
         -2.8528, -2.7547, -2.9139, -2.9130, -2.8154, -3.0121, -2.9169, -2.9532,
         -2.9210, -2.9176]], grad_fn=<LogSoftmaxBackward0>)
('Korean', 9)
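Since LogSoftmax returns log-probabilities, exponentiating recovers probabilities that sum to 1 across the 18 classes; for the untrained model they are close to uniform (about 1/18 ≈ 0.056 each). A quick check on the output above:

# exp() turns log-probabilities back into probabilities.
probs = output.exp()
print(probs.sum())   # ~1.0
print(probs.max())   # near-uniform before training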
In [15]:
import torch
import tempfile
from torch.utils.tensorboard import SummaryWriter

# write TensorBoard logs to a throwaway temp directory
log_dir = tempfile.mkdtemp(prefix="tensorboard_logs_")
writer = SummaryWriter(log_dir)
In [16]:
import time
import random
import torch.nn as nn
import numpy as np

def train(rnn, training_data, n_epoch=10, n_batch_size=64, report_every=50, learning_rate=0.2, criterion=nn.NLLLoss()):
    current_loss = 0
    all_losses = []
    rnn.train()
    optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)

    start = time.time()
    print(f"training on data set with n = {len(training_data)}")

    for iter in range(1, n_epoch + 1):
        rnn.zero_grad()  # clear the gradients

        # build random minibatches of indices by hand; a DataLoader is awkward
        # here because each name tensor has a different length
        batches = list(range(len(training_data)))
        random.shuffle(batches)
        batches = np.array_split(batches, len(batches) // n_batch_size)

        for idx, batch in enumerate(batches):
            batch_loss = 0
            for i in batch:  # for each example in this batch
                (label_tensor, text_tensor, label, text) = training_data[i]
                output = rnn.forward(text_tensor)
                loss = criterion(output, label_tensor)
                batch_loss += loss

            # optimize parameters once per batch
            batch_loss.backward()
            nn.utils.clip_grad_norm_(rnn.parameters(), 3)
            optimizer.step()
            optimizer.zero_grad()

            current_loss += batch_loss.item() / len(batch)

        now_loss = current_loss / len(batches)
        writer.add_scalar("Loss/train", now_loss, iter)
        all_losses.append(now_loss)
        if iter % report_every == 0:
            print(f"{iter} ({iter / n_epoch:.0%}): \t average batch loss = {all_losses[-1]}")
        current_loss = 0

    return all_losses
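Training here uses the full dataset, so there is no held-out accuracy. A hypothetical evaluation helper in the same style (evaluate is not part of the original notebook; with a train/test split it would report accuracy on unseen names):

def evaluate(rnn, dataset):
    # Hypothetical helper: classification accuracy over a NameDataset.
    rnn.eval()
    correct = 0
    with torch.no_grad():   # no gradients needed for evaluation
        for i in range(len(dataset)):
            label_tensor, text_tensor, label, text = dataset[i]
            guess, _ = get_label_from_output(rnn(text_tensor), dataset.uniq_labels)
            correct += int(guess == label)
    return correct / len(dataset)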
In [17]:
os.system(f"tensorboard --logdir {log_dir} --host 0.0.0.0 --port 8081 &")
Out[17]:
0
TensorFlow installation not found - running with reduced feature set.
TensorBoard 2.20.0 at http://0.0.0.0:8081/ (Press CTRL+C to quit)
In [18]:
import time
start = time.time()
all_losses = train(rnn, name_dataset, n_epoch=27, learning_rate=0.15, report_every=5)
writer.close()
end = time.time()
print(f"training took {end-start}s")
training on data set with n = 20074
5 (19%):   average batch loss = 0.8614327898275124
10 (37%):  average batch loss = 0.6777308707815566
15 (56%):  average batch loss = 0.5674129698142633
20 (74%):  average batch loss = 0.494484276482531
25 (93%):  average batch loss = 0.43785551377950005
training took 187.33912801742554s
In [20]:
import matplotlib.pyplot as plt
plt.figure()
plt.plot(all_losses)
plt.show()
In [21]:
input = one_hot_embedding('Sun')
print(input.shape)
output = rnn(input)
print(output)
print(get_label_from_output(output, name_dataset.uniq_labels))
torch.Size([3, 1, 100])
tensor([[-13.0435,  -7.4312, -11.4961, -17.3842,  -6.8836,  -9.0367,  -8.3997,
         -12.6531, -12.5519,  -1.5192,  -7.1438,  -8.8310,  -0.4023, -11.6027,
          -3.4756,  -4.6570, -13.3974,  -2.6739]], grad_fn=<LogSoftmaxBackward0>)
('Chinese', 12)
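To see more than the single best guess, a small top-k helper can print the model's runners-up as well. A sketch (predict is not defined elsewhere in this notebook):

def predict(name, n_predictions=3):
    # Print the top-k language guesses with their log-probabilities.
    rnn.eval()
    with torch.no_grad():
        output = rnn(one_hot_embedding(name))
        topv, topi = output.topk(n_predictions, dim=1)
        for v, i in zip(topv[0], topi[0]):
            print(f"{name}: {name_dataset.uniq_labels[i.item()]} ({v.item():.2f})")

predict('Sun')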