본문 바로가기
논문/논문 리뷰

논문 리뷰 2. Stock Price Movement Prediction Using Sentiment Analysis and CandleStick Chart Representation()

by p-jiho 2023. 3. 8.

논문 명 : Stock Price Movement Prediction Using Sentiment Analysis and CandleStick Chart Representation. 2021

저자 : Trang-Thi Ho, Yennun Huang

 

이 전에 이 논문의 내용에 따라 5가지 모델로 감성분석을 했었다.

오늘은 이 논문에서 나온 Candlestick chart에 대한 분석을 해볼 것이다.

Candlestick chart를 그림으로 저장해 그 그림을 분석하였다.

 

필요한 패키지

import pandas as pd
import yfinance as yf
from sklearn.model_selection import train_test_split

# !pip install https://github.com/matplotlib/mpl_finance/archive/master.zip
from mpl_finance import candlestick_ohlc
import numpy as np
import matplotlib.pyplot as plt

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Activation, Dropout, Flatten
from sklearn.model_selection import KFold, GridSearchCV
from keras.callbacks import ReduceLROnPlateau
from keras.wrappers.scikit_learn import KerasClassifier
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam, SGD, RMSprop

from sklearn.preprocessing import MinMaxScaler
from keras.layers import Conv1D, Concatenate, BatchNormalization, GlobalAveragePooling1D
from keras import Model
from sklearn.preprocessing import OneHotEncoder

위에 from mpl_finance import cnadlestick_ohlc는 candlestick chart를 그리기 위한 패키지이다.

그 위에 이 패키지를 깔 수 있는 코드를 명시해두었다. 코드 그대로 설치하고 사용하면 된다.

 

price_data = yf.download("^DJI",start = '2011-12-31', end = '2022-05-01')
price_data = price_data[['Open', 'High', 'Low', 'Close', 'Volume']]
price_data = price_data.reset_index()

n = 4
price_data["Future_trend"] = price_data.Close - price_data.Close.shift(n)
price_data.Future_trend[0:(price_data.shape[0]-n)] = price_data.Future_trend[n:(price_data.shape[0])]
price_data = price_data.iloc[0:(price_data.shape[0]-n)]
price_data.Future_trend = price_data.Future_trend.apply(lambda x: 1 if x>0 else 0)
price_data.Date = price_data.Date.apply(lambda x: x.year*10000+x.month*100 + x.day)

나는 일단 다우존스와 애플의 주가를 예측해보았다. 다우존스는 ^DJI, 애플은 AAPL 을 넣어서 해당 기간에 맞는 데이터를 추출하면 된다.

추출한 데이터는 시가, 고가, 저가, 종가, 거래량을 사용한다.

n은 n day 후를 예측한다는 뜻이다. 즉, 1일의 시가, 고가, 저가, 종가, 거래량을 가지고 n일 후에 1일보다 오르는지 내리는 지를 예측하는 것이다.

이 내용에 맞게 데이터를 생성해준다.

shift 함수를 이용해서 원하는 만큼 위치를 옮겨준다. 그리고 오늘의 종가에서 빼주면 오늘보다 얼마나 오르고 내렸는지를 알 수 있다.

그림으로 설명을 해보았다.

1번째 그림에서 input과 output data가 있다. input에는 시가, 고가, 저가, 종가, 거래량이 있고 output에는 오르내림 즉, trend가 있다.

2번째 그림에서 n=4로 shift 하고 이를 종가에서 뺐으므로 (5일 종가 - 1일 종가)가 trend가 된다.

3번째 그림에서 1일의 데이터로 4일 후의 trend를 예측해야하므로 그에 맞게 데이터를 옮겨주고

4번째 그림과 같이 데이터 구조에 맞지않는 데이터는 삭제시켜준다.

이렇게 데이터를 구성하고 Date를 "yyyy-mm-dd"형식에서 "yyyymmdd" 형식으로 lambda함수를 이용해 변환시킨다.