Background
Back in the dark ages (yes, I mean before ChatGPT took the internet by storm), transformer-based text generation models were known only to AI researchers and the nerdiest of programmers. It wasn't until the release of OpenAI's GPT-2 in early 2019 that I joined the hype train. Like everyone else at the time, I wanted to co-opt the power of transformers for fun and profit, but alas, OpenAI was only granting big-time researchers and deep-pocketed corporations access to its groundbreaking technology. Thankfully, in the meantime, Ben Wang and Aran Komatsuzaki were assembling a crack team of open-source gods to compete with OpenAI. Working at breakneck pace, they released GPT-J, a free and open competitor to GPT-2 and GPT-3, by mid-2021, just in time for me to make use of their work in my senior research project for my college degree.
Timeline
My Research Project
Concept
In mid-2021, there was a lot of excitement around the flexibility of large transformer models. I wanted to see if I could use GPT-J as an all-in-one news-based trading bot. The idea was that if a model has read enough of the internet as background, it should be able to understand the context of a news article and output a trading signal for the company involved.
Input: News article about a company ⮕ Output: Buy/Sell/Neutral signal
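To make this concrete, here is a minimal sketch of what a single prediction could look like using the Hugging Face transformers port of GPT-J. The prompt template, decoding settings, and example article are illustrative assumptions, not the exact ones used in my project.

# Minimal inference sketch (assumed prompt template and settings, for illustration only).
import torch
from transformers import AutoTokenizer, GPTJForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16).to("cuda")

article = "Acme Corp. reported quarterly earnings well above analyst estimates..."
prompt = article + "\n\nTrading signal (Buy/Sell/Neutral):"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=2, do_sample=False)
signal = tokenizer.decode(output[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
print(signal)  # e.g. "Buy"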
Data gathering
For the best chance of success, I decided to fine-tune the GPT-J-6B model with data formatted like the prompts I intended to test with. I scraped and labeled 140,000 news articles from the below web sources:
To avoid rate limiting, I built a custom web scraper (shown below) that uses and manages a pool of rotating proxies, along with a per-domain rate limiter to ensure I was not making too many requests to any given site. Running the scraper from an Oracle VM, I was able to scrape around 40 articles per minute and finish all 140,000 articles in roughly 2.5 days.
import logging
import queue
import threading
import time

import requests
import tldextract

# Note: emptyLogFile() and loadCookies() are small helpers defined elsewhere in the project.


class Scraper:
"""
Scrapes a list of seed urls and calls the parser function to process each result.
    The parser function is called with four arguments: the response object, the url,
    the domain, and the addUrls method, which the parser can use to queue newly discovered urls.
The scraper attempts to never scrape the same url twice.
"""
def __init__(self, seedURLs = [], parser = None, options = None):
emptyLogFile()
self.options = {
'cookieDirectory': './cookies/',
'scrapeThreads': 10,
'rateLimits': {},
'maxAttempts': 3,
'proxyUpdateInterval': 10
}
if options:
self.options.update(options)
self._cookies = loadCookies(self.options['cookieDirectory'])
self._parser = parser
        self._urlQueue = queue.PriorityQueue() # Priority queue of urls to scrape. Priority is the number of times a url has been attempted.
self._finishedUrls = {} # Dictionary of urls that have been scraped or have been attempted the max number of times.
self._finishedUrlsLock = threading.Lock()
self.addUrls(seedURLs)
self._threadNumber = self.options['scrapeThreads']
# Current queue of proxy sessions. Periodically updated.
self._proxySessions = queue.Queue()
for proxy in self.verifyProxies(self.getProxies()):
self._proxySessions.put(proxy)
# Dictionary of domains and their rate limits and the last time a request was made.
self._rateLimits = {}
self._rateLimitsLock = threading.Lock()
for domain in self.options['rateLimits']:
self._rateLimits[domain] = {
'limit': self.options['rateLimits'][domain],
'lastRequest': {}
}
logging.debug("Initialized scraper with " + str(self._threadNumber) + " threads.")
def addUrls(self, urls):
""" Adds a list of urls to the list of urls to scrape. """
count = 0
with self._finishedUrlsLock:
for url in urls:
if url not in self._finishedUrls:
self._urlQueue.put((0, url))
count += 1
logging.debug("Added " + str(count) + " urls to the queue.")
def _scrape(self):
"""
Scraping thread worker.
Scrapes urls from the list of urls and calls the parser function.
"""
while True:
if self._urlQueue.empty(): break
attempts, url = self._urlQueue.get()
domain = tldextract.extract(url).domain + '.' + tldextract.extract(url).suffix
            if self._proxySessions.empty():
                # No proxy session available right now; re-queue the url and wait briefly.
                self._urlQueue.put((attempts, url))
                time.sleep(1)
                continue
            session = self._proxySessions.get()
with self._rateLimitsLock:
sessionKey = repr(session)
if domain in self._rateLimits:
if sessionKey in self._rateLimits[domain]['lastRequest']:
if max(time.time() - self._rateLimits[domain]['lastRequest'][sessionKey],0) <= self._rateLimits[domain]['limit']:
                            # Too soon to hit this domain with this proxy; re-queue the url,
                            # return the proxy session to the pool, and move on.
                            self._urlQueue.put((attempts, url))
                            self._proxySessions.put(session)
                            continue
else:
self._rateLimits[domain]['lastRequest'][sessionKey] = time.time()
else:
                        # Record the request time so the very next hit to this domain is also throttled.
                        self._rateLimits[domain]['lastRequest'][sessionKey] = time.time()
else:
self._rateLimits[domain] = {
'limit': 0,
'lastRequest': {}
}
try:
r = session.get(url, timeout=10)
if r.status_code == 200:
if self._parser: self._parser(r, url, domain, self.addUrls)
self._finishedUrls[url] = True
elif r.status_code == 429:
self._urlQueue.put((attempts, url))
with self._rateLimitsLock:
                        retry_after = r.headers.get('Retry-After') or r.headers.get('x-retry-after')
                        try:
                            self._rateLimits[domain]['limit'] = float(retry_after)
                        except (TypeError, ValueError):
                            # No usable Retry-After header; gently back off instead.
                            self._rateLimits[domain]['limit'] += 0.5
                        logging.warning("429 occurred with url " + url + ". Rate limit for " + domain + " is now " + str(self._rateLimits[domain]['limit']) + " seconds.")
elif r.status_code == 404:
self._finishedUrls[url] = False
else:
raise Exception("Unexpected status code: " + str(r.status_code) + " for url: " + url)
except Exception as e:
attempts += 1
if attempts < self.options['maxAttempts']:
self._urlQueue.put((attempts, url))
else:
self._finishedUrls[url] = False
logging.warning("Error scraping " + url + " Attempts: " + str(attempts) + " with proxy " + str(session.proxies) + ": " + str(e))
self._proxySessions.put(session)
def getProxies(self):
"""Gets proxies from https://www.socks-proxy.net/"""
sourceURL5 = 'http://list.didsoft.com/get?email=joelskyler@gmail.com&pass=u3t3mm&pid=socks1100&showcountry=no&version=socks5'
r = requests.get(sourceURL5)
data5 = r.text.split('\n')
return [{'http': 'socks5://' + x, 'https': 'socks5://' + x} for x in data5 if x != '']
def _verifyProxy(self, proxy, url, timeout=5, verifiedProxies={}):
""" Verifies a proxy is workinng by requesting a url. """
try:
# Create new sesssion with proxy
session = requests.Session()
session.proxies = proxy
session.headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/74.0.3729.169 Safari/537.36',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3',
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US,en;q=0.9',
'Connection': 'keep-alive',
}
session.cookies = self._cookies
r = session.get(url, timeout=timeout)
if r.status_code == 200:
verifiedProxies[repr(proxy)] = session
except Exception as e:
logging.warning("Error verifying proxy " + repr(proxy) + " for url " + url + "Exception" + str(e))
def verifyProxies(self, proxies):
""" Verifies a list of proxies. """
verifiedProxies = {}
proxyQueue = queue.Queue()
for proxy in proxies:
proxyQueue.put(proxy)
threadNum = len(proxies)
threads = []
for i in range(threadNum):
t = threading.Thread(target = self._verifyProxy, args = (proxyQueue.get(), 'https://www.ft.com/', 25, verifiedProxies))
threads.append(t)
t.start()
for t in threads:
result = t.join()
goodProxies = verifiedProxies.values()
logging.debug("Verified " + str(len(goodProxies)) + " proxies.")
print("Verified " + str(len(goodProxies)) + " proxies.")
return goodProxies
def _maintainProxies(self):
"""
Periodically updates the list of proxies.
"""
lastTimeUpdated = time.time()
while True:
if self._urlQueue.empty():
break
time.sleep(1)
if time.time() - lastTimeUpdated > 60 * self.options['proxyUpdateInterval']:
newProxies = self.verifyProxies(self.getProxies())
# Empty the queue of proxies.
while not self._proxySessions.empty():
self._proxySessions.get()
# Add the new proxies to the queue.
for proxy in newProxies:
self._proxySessions.put(proxy)
lastTimeUpdated = time.time()
def _progress(self):
"""
Periodically prints the progress of the scraping.
"""
while not self._urlQueue.empty():
queueSize = self._urlQueue.qsize()
with self._finishedUrlsLock:
finishedUrls = len(self._finishedUrls)
progressString = "Progress: " + str(finishedUrls) + " Remaining: " + str(queueSize)
print("\r"+" "*len(progressString)+"\r"+progressString, end="")
time.sleep(0.1)
def run(self):
""" Starts the scraping threads and the proxy maintenance thread. """
scrapingThreads = []
for i in range(self._threadNumber):
t = threading.Thread(target = self._scrape)
t.start()
scrapingThreads.append(t)
proxyMaintThread = threading.Thread(target = self._maintainProxies)
proxyMaintThread.start()
progressThread = threading.Thread(target = self._progress)
progressThread.start()
for t in scrapingThreads:
t.join()
proxyMaintThread.join()
progressThread.join()
successfulScrapes = len([v for v in self._finishedUrls.values() if v])
print(f"\nFinished scraping. {len(self._finishedUrls)} urls scraped. {successfulScrapes} successful scrapes.")
I then split my dataset into a training set of 90,000 articles, a validation set of 10,000 articles, and a test set of 40,000 articles.
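A sketch of that split, assuming the labeled articles are held in a single Python list (the actual storage format isn't shown here):

# Hypothetical train/validation/test split of the 140,000 labeled articles.
import random

random.seed(42)           # fixed seed so the split is reproducible
random.shuffle(articles)  # `articles`: list of labeled examples (assumed)
train = articles[:90_000]
val = articles[90_000:100_000]
test = articles[100_000:]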
Training
Because the full model checkpoint (weights plus optimizer state) comes to over 60GB, specialized hardware was required for this training task. Training was done on a TPU v3-8, generously provided by Google's TPU Research Cloud. I trained my model using the following hyperparameters:
{
"layers": 28,
"d_model": 4096,
"n_heads": 16,
"n_vocab": 50400.
"warmup_steps": 160,
"anneal_steps": 1530,
"lr": 1.2e-4,
"end_lr": 1.2e-5,
"weight_decay": 0.1,
"total_steps": 1700,
"tpu_size": 8,
"bucket": "peckham_tpu_europe",
"model_dir": "mesh_jax_stock_model_slim_f16",
"train_set": "stocks.train.index",
"val_set": {
"stocks": "stocks.val.index"
},
"val_batches": 5777,
"val_every": 500,
"ckpt_every": 500,
"keep_every": 10000
}
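To make the schedule-related numbers concrete, here is a rough sketch of the learning-rate schedule these settings describe: a linear warmup to lr over warmup_steps, followed by an anneal down to end_lr over anneal_steps. The actual curve produced by the training code may differ in shape (e.g. cosine rather than linear).

# Rough illustration of the learning-rate schedule implied by the config above (assumed linear anneal).
def learning_rate(step, lr=1.2e-4, end_lr=1.2e-5, warmup_steps=160, anneal_steps=1530):
    if step < warmup_steps:
        return lr * step / warmup_steps                # linear warmup from 0 to lr
    progress = min((step - warmup_steps) / anneal_steps, 1.0)
    return end_lr + (lr - end_lr) * (1.0 - progress)   # anneal from lr down to end_lr

print(learning_rate(0), learning_rate(160), learning_rate(1690))  # 0.0, 0.00012, 1.2e-05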
Training Loss
At first, I was encouraged by seeing a meaningful decrease in training loss; however, I now believe this was due to overfitting to the training set.
Results
Fine-tuned model: 50.58% correct guesses of stock price movement direction (N=40K).
Standard GPT-J model: 50.63% correct guesses of stock price movement direction (N=40K).
Human testing: 56% correct guesses of stock price movement direction (N=120).
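Here, a "correct guess" means the signal matched the direction the stock actually moved. A sketch of that comparison, with an assumed test-set format and a hypothetical predict_signal() helper:

# Sketch of the accuracy calculation; `test_set`, its field names, and predict_signal()
# are assumptions for illustration.
def accuracy(test_set, predict_signal):
    correct = sum(
        1 for example in test_set
        if predict_signal(example["article"]) == example["direction"]  # "Buy" / "Sell" / "Neutral"
    )
    return correct / len(test_set)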
Conclusion
In my testing, GPT-J performs no better than random when predicting the direction of price movement based on a news story. Furthermore, fine-tuning the model does not improve prediction accuracy.
I believe these results are due to the following factors:
- The model (except for a small amount of fine-tuning) is not trained on financial data.
- The data was not cleaned well enough to remove noise.
- Text data is not a good representation of financial data. In other words, the model doesn’t care about the actual numbers, only how well they “sound” in the context of the article.
- The articles may also not have enough information to predict price movement anyway. News is often considered a lagging indicator of price movement, and the articles I scraped were often published after the relevant price movement had already occurred.