KAN-GPT


The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling

Install it from PyPI

pip install kan_gpt
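To verify the install, a quick sanity check is to import the GPT class used in the examples below:

python -c "from kan_gpt.model import GPT; print('kan_gpt installed')"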

Usage

Refer to the KAN_GPT.ipynb notebook and kan_gpt/prompt.py for usage examples. The following is an outline of how to use the model:

import torch
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer

model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt = "Bangalore is often described as the "

prompt_encoded = tokenizer.encode(
  text=prompt, add_special_tokens=False
)

x = torch.tensor(prompt_encoded).unsqueeze(0)

model.eval()
y = model.generate(x, 50)  # sample 50 tokens

result = tokenizer.decode(y[0])  # drop the batch dimension before decoding

print(result)

# Bangalore is often described as the Silicon Valley of India.
# The city has witnessed rapid growth in the past two decades.....
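The model follows a minGPT-style training interface. Below is a minimal sketch of a single optimization step, assuming the forward pass accepts (inputs, targets) and returns (logits, loss) as in minGPT; the random toy batch and learning rate are illustrative only, not recommended settings.

import torch
from kan_gpt.model import GPT

model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)

# Toy batch: for next-token prediction, targets are the inputs shifted by one.
tokens = torch.randint(0, model_config.vocab_size, (1, 64 + 1))
inputs, targets = tokens[:, :-1], tokens[:, 1:]

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

model.train()
logits, loss = model(inputs, targets)  # assumed (logits, loss) return, as in minGPT
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())

In practice, the kan_gpt.train entry point (see Train below) wraps this loop with dataset loading and logging.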

Setup for Development

# Download Repo
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull

# Download Dataset
./scripts/download_webtext.sh
./scripts/download_tinyshakespeare.sh

# Install dependencies for development
pip install -r requirements.txt
pip install -e .

Train

Use the following dummy runs to make sure everything is working as expected:

WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture MLP --batch_size 1 --dummy_dataset --device cpu --max_iters 200
WANDB_MODE=offline CUDA_VISIBLE_DEVICES="" python3 -m kan_gpt.train --architecture KAN --batch_size 1 --dummy_dataset --device cpu --max_iters 200

Then make use of the training script:

python -m kan_gpt.train
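For a real (non-dummy) run, the same flags from the dummy invocations above apply; the values here are illustrative assumptions rather than tuned settings:

WANDB_MODE=offline python3 -m kan_gpt.train --architecture KAN --batch_size 32 --device cuda --max_iters 5000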

Prompt

You can prompt the model to produce text as follows:

python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path <path-to-checkpoint>
