first commit
parent 35c076f20b
commit 695c49cb7b
10 changed files with 425 additions and 33 deletions
@@ -7,4 +7,3 @@ Version 0.1
- Feature A added
- FIX: nasty bug #1729 fixed
- add your changes here!
163 README.rst
@@ -1,19 +1,160 @@
-============
-gpus_monitor
-============
-Add a short description here!
-Description
-===========

# GPUs Monitor

`gpus_monitor` is a Python GPU activity monitoring tool that reports, by email, new and recently terminated compute processes on the machine it runs on.

When you start a new training on a machine where `gpus_monitor` is listening, you will receive an email notification within a few seconds. This email contains several pieces of information about the process you have just launched.

You also receive an email when a compute process dies (whether its exit status is 0 or not).

### The kind of email gpus_monitor sends you

1. New training detected:
>> From: <gpusstatus@mydomain.com>
>> Subject: 1 processes running on <MACHINE_NAME> (LOCAL_IP_OF_THE_MACHINE)
>>
>> New events (triggered on the 08/11/2020 11:45:52):
>>
>> ---------------------------------------------------------------------------------------------------------------
>> A new process (PID: 12350) has been launched on GPU 0 (Quadro RTX 4000) by <owner_of_the_process> since 08/11/2020 11:43:48
>> Its owner (<owner_of_the_process>) executed the following command:
>> python3 test_torch.py
>> From:
>> <absolute_path_to_the_script_of_the_new_launched_process>
>>
>> CPU Status (currently):
>> For this process: 19/40 logic cores (47.5%)
>>
>> GPU Status (currently):
>> - Used memory (for this process): 879 / 7979.1875 MiB (11.02 % used)
>> - Used memory (for all processes running on this GPU): 7935.3125 / 7979.1875 MiB (99.45 % used)
>> - Temperature: 83 Celsius
>> - Driver version: 435.21
>> ---------------------------------------------------------------------------------------------------------------
>>
>> This message was sent automatically by a robot. Please do not reply to this email.
>> Please feel free to open a pull request on github.com/araison12/gpus_monitor if you have encountered a bug or want to share your ideas to improve this tool.

2. Training ended (whether it finished cleanly or not):

>> From: <gpusstatus@mydomain.com>
>> Subject: 1 processes died on <MACHINE_NAME> (LOCAL_IP_OF_THE_MACHINE)
>>
>> New events (triggered on the 08/11/2020 11:47:29):
>>
>> ---------------------------------------------------------------------------------------------------------------
>> The process (PID: 12350) launched by araison since 08/11/2020 11:43:48 has ended.
>> Its owner araison executed the following command:
>> python3 test_torch.py
>> From:
>> <absolute_path_to_the_script_of_the_died_process>
>>
>> The process took 0:03:41 to finish.
>> ---------------------------------------------------------------------------------------------------------------
>>
>> This message was sent automatically by a robot. Please do not reply to this email.
>> Please feel free to open a pull request on github.com/araison12/gpus_monitor if you have encountered a bug or want to share your ideas to improve this tool.

## Instructions for using gpus_monitor

1. Clone this repository:

`git clone https://github.com/araison12/gpus_monitor.git`

2. Install the dependencies:

`pip3 install -r gpus_monitor/requirements.txt`

or

`python3 setup.py install --user`
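
The dependencies can also be installed explicitly. Note that the `yaml` module imported by the code is provided by the PyYAML distribution on PyPI (a minimal sketch, assuming a user-level install is acceptable):

```bash
# PyYAML provides the `yaml` module; psutil, netifaces and pynvml are imported by tools.py and main.py.
pip3 install --user pyyaml psutil netifaces pynvml
```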

3. Add people's email addresses to the list in the `persons_to_inform.yaml` file:

Example:

```yaml
list:
- adrien.raison@univ-poitiers.fr
- other_person_to_inform@hisdomain.com
```

Note: you can add or remove email addresses in this file on the fly, without having to kill the scanning process!
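
Because `config.persons_to_inform()` re-reads this file every time a notification is sent, you can, for example, append a new recipient while the monitor is running (the address below is a placeholder; this assumes the flat list layout shown above):

```bash
# Append a new recipient; it is picked up at the next notification without a restart.
echo '- new_colleague@mydomain.com' >> persons_to_inform.yaml
```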

4. Add the SMTP server parameters (server address, credentials, port number, etc.):

You can manage these settings by editing the `gpus_monitor/src/gpus_monitor/config.py` file:

```bash
cd gpus_monitor/src/gpus_monitor/
vim config.py
```

For privacy reasons, the credentials of my dedicated SMTP account are stored in two machine environment variables. I have set up a brand-new Gmail account for my `gpus_monitor` instance. If you would like to use a single SMTP account for `gpus_monitor` instances on several machines, feel free to send me an email and I can share the credentials!

Otherwise, fill in your own SMTP server configuration.

```python
USER = os.environ.get(
    "GPUSMONITOR_MAIL_USER"
)

PASSWORD = os.environ.get(
    "GPUSMONITOR_MAIL_PASSWORD"
)
PORT = 465
SMTP_SERVER = "smtp.gmail.com"
```

See https://askubuntu.com/a/58828 for how to add environment variables permanently.
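
For example, on Ubuntu-like systems the two variables can be made permanent and system-wide via `/etc/environment` (a minimal sketch; the address and password are placeholders for your own SMTP credentials):

```bash
# Add the two variables system-wide; they are loaded at the next login session.
echo 'GPUSMONITOR_MAIL_USER="gpusstatus@mydomain.com"' | sudo tee -a /etc/environment
echo 'GPUSMONITOR_MAIL_PASSWORD="replace_with_your_smtp_password"' | sudo tee -a /etc/environment
```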

5. Adjust the scanning rate of `gpus_monitor` and the minimum age of the processes it takes into account.

The `WAITING_TIME` variable sets how often gpus_monitor scans the GPUs:

```python
WAITING_TIME = 0.5  # min
```

The `PROCESS_AGE` variable sets the minimum age a process must reach before gpus_monitor takes it into account:

```python
PROCESS_AGE = 2  # min (gpus_monitor only considers processes that are at least 2 min old)
```

6. Run `gpus_monitor` when the machine starts up:

```bash
crontab -e
```

Add the following line to the newly opened file:

```bash
@reboot python3 /path/to/gpus_monitor/src/gpus_monitor/main.py
```
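
If you want to keep a trace of what the monitor prints (or debug a start-up problem), a common variant is to redirect its output to a log file; the log path below is only an example:

```bash
@reboot python3 /path/to/gpus_monitor/src/gpus_monitor/main.py >> /tmp/gpus_monitor.log 2>&1
```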

## Ideas to enhance the project

- Logging system (owner, total compute time per user)
- Distinguish cases in the email subject: process finished cleanly or not (and send the traceback)
- Centralized system that scans every machine in a given range of IP addresses
- Better error handling (SMTP connection failure, no CUDA GPU on the machine, ...)
- Document the project
- Rewrite it in an object-oriented fashion

If you have any ideas to improve this project, don't hesitate to open a pull request! :)

-A longer description of your project goes here...
-Note
-====
-This project has been set up using PyScaffold 3.2.3. For details and usage
-information on PyScaffold see https://pyscaffold.org/.
5 persons_to_inform.yaml (new file)
@@ -0,0 +1,5 @@
---

list:
- adrien.raison@univ-poitiers.fr
# - <user@mydomain.com> Add your mail here and uncomment this line
requirements.txt
@@ -1,17 +1,4 @@
-# =============================================================================
-# DEPRECATION WARNING:
-#
-# The file `requirements.txt` does not influence the package dependencies and
-# will not be automatically created in the next version of PyScaffold (v4.x).
-#
-# Please have look at the docs for better alternatives
-# (`Dependency Management` section).
-# =============================================================================
-#
-# Add your pinned requirements so that they can be easily installed with:
-# pip install -r requirements.txt
-# Remember to also add them in setup.cfg but unpinned.
-# Example:
-# numpy==1.13.3
-# scipy==1.0
-#
pyyaml  # provides the `yaml` module
psutil
netifaces
pynvml
setup.cfg
@@ -4,15 +4,15 @@
[metadata]
name = gpus_monitor
-description = Add a short description here!
description = GPUs Monitoring tool for machine learning purposes
author = araison
author-email = adrien.raison@univ-poitiers.fr
license = proprietary
long-description = file: README.rst
long-description-content-type = text/x-rst; charset=UTF-8
-url = https://github.com/pyscaffold/pyscaffold/
url = https://github.com/araison12/gpus_monitor
project-urls =
-    Documentation = https://pyscaffold.org/
    # Documentation =
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = any
# Add here all kinds of additional classifiers as defined under
@@ -30,7 +30,7 @@ package_dir =
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = pyscaffold>=3.2a0,<3.3a0
# Add here dependencies of your project (semicolon/line-separated), e.g.
-# install_requires = numpy; scipy
install_requires = pyyaml;psutil;netifaces;pynvml
# The usage of test_requires is discouraged, see `Dependency Management` docs
# tests_require = pytest; pytest-cov
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
3 src/gpus_monitor/.vscode/settings.json (new file, vendored)
@@ -0,0 +1,3 @@
{
    "python.pythonPath": "/usr/bin/python3"
}
25 src/gpus_monitor/config.py (new file)
@@ -0,0 +1,25 @@
import os
import yaml

# SMTP SERVER CONFIG

USER = os.environ.get(
    "GPUSMONITOR_MAIL_USER"
)  # for the moment, only accessible from SIC08005 (all users)
PASSWORD = os.environ.get(
    "GPUSMONITOR_MAIL_PASSWORD"
)  # for the moment, only accessible from SIC08005 (all users)
PORT = 465
SMTP_SERVER = "smtp.gmail.com"

# SCANNING RATE

WAITING_TIME = 0.5  # min
PROCESS_AGE = 2  # min


# PERSONS TO INFORM
def persons_to_inform():
    # Resolve the YAML file relative to this module (two levels up, at the
    # repository root) so the monitor can be started from any working
    # directory, e.g. from cron at boot.
    yaml_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "..", "persons_to_inform.yaml"
    )
    # The file is re-read on every call, so addresses can be added or removed
    # without restarting the monitor.
    with open(yaml_path, "r") as yaml_file:
        PERSON_TO_INFORM_LIST = yaml.load(yaml_file, Loader=yaml.FullLoader)["list"]
    return PERSON_TO_INFORM_LIST
118 src/gpus_monitor/main.py (new file)
@@ -0,0 +1,118 @@
import time
import tools
import psutil
import config
import datetime


def main():

    machine_infos = tools.get_machine_infos()
    # PIDs and info dicts of the compute processes currently being tracked.
    processes_to_monitor_pid = []
    processes_to_monitor = []

    while True:
        time.sleep(config.WAITING_TIME * 60)
        current_time = time.time()
        gpus_info = tools.gpus_snap_info()
        driver_version = gpus_info["driver_version"]
        news = ""
        send_info = False
        new_processes_count = 0
        died_processes_count = 0

        # Look for compute processes that appeared since the last scan.
        for index, gpu in enumerate(gpus_info["gpu"]):
            gpu_name = gpu["product_name"]
            processes = gpu["processes"]
            if processes == "N/A":
                news += f"Nothing is running on GPU {index} ({gpu_name})\n"
                continue
            else:
                for p in processes:
                    pid = p["pid"]
                    p_info = tools.process_info(pid)
                    if (
                        current_time - p_info["since"] >= config.PROCESS_AGE * 60
                        and pid not in processes_to_monitor_pid
                    ):
                        send_info = True
                        new_processes_count += 1
                        news += f"""
---------------------------------------------------------------------------------------------------------------
A new process (PID: {pid}) has been launched on GPU {index} ({gpu_name}) by {p_info['owner']} since {datetime.datetime.fromtimestamp(int(p_info['since'])).strftime('%d/%m/%Y %H:%M:%S')}
Its owner ({p_info['owner']}) executed the following command:
{' '.join(p_info['executed_cmd'])}
From:
{p_info['from']}

CPU Status (currently):
For this process: {p_info['cpu_core_required']}

GPU Status (currently):
- Used memory (for this process): {p['used_memory']} / {gpu['fb_memory_usage']['total']} {gpu['fb_memory_usage']['unit']} ({round(p['used_memory']/gpu['fb_memory_usage']['total']*100,2)} % used)
- Used memory (for all processes running on this GPU): {gpu['fb_memory_usage']['used']} / {gpu['fb_memory_usage']['total']} {gpu['fb_memory_usage']['unit']} ({round(gpu['fb_memory_usage']['used']/gpu['fb_memory_usage']['total']*100,2)} % used)
- Temperature: {gpu['temperature']['gpu_temp']} Celsius
- Driver version: {driver_version}
---------------------------------------------------------------------------------------------------------------

"""
                        processes_to_monitor.append(p_info)
                        processes_to_monitor_pid.append(pid)
                    else:
                        continue

        # Check whether any tracked process has ended since the last scan.
        for p in processes_to_monitor[:]:
            pid = p["pid"]
            try:
                psutil.Process(pid)
                continue
            except psutil.NoSuchProcess:
                send_info = True
                died_processes_count += 1
                news += f"""
---------------------------------------------------------------------------------------------------------------
The process (PID: {pid}) launched by {p['owner']} since {datetime.datetime.fromtimestamp(int(p['since'])).strftime('%d/%m/%Y %H:%M:%S')} has ended.
Its owner {p['owner']} executed the following command:
{' '.join(p['executed_cmd'])}
From:
{p['from']}

The process took {datetime.timedelta(seconds=int(current_time)-int(p['since']))} to finish.
---------------------------------------------------------------------------------------------------------------

"""
                processes_to_monitor.remove(p)
                processes_to_monitor_pid.remove(pid)

        subject = None

        if new_processes_count > 0:
            subject = f"{new_processes_count} processes running on {machine_infos['MACHINE_NAME']} ({machine_infos['LOCAL_IP']})"
        elif died_processes_count > 0:
            subject = f"{died_processes_count} processes died on {machine_infos['MACHINE_NAME']} ({machine_infos['LOCAL_IP']})"
        else:
            subject = "Error"

        now = datetime.datetime.now()
        dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
        global_message = f"""
New events (triggered on the {dt_string}):
{news}

This message was sent automatically by a robot. Please do not reply to this email.

Please feel free to open a pull request on github.com/araison12/gpus_monitor if you have encountered a bug or want to share your ideas to improve this tool :)

"""

        if send_info:
            for person in config.persons_to_inform():
                tools.send_mail(subject, global_message, person)


if __name__ == "__main__":
    main()
65 src/gpus_monitor/tools.py (new file)
@@ -0,0 +1,65 @@
import os
import ssl
import time
import config
import psutil
import smtplib
import netifaces as ni
from socket import gaierror
from pynvml.smi import nvidia_smi
from email.message import EmailMessage


def send_mail(subject, message, receiver):
    # Send a plain-text email through the SMTP server configured in config.py.
    context = ssl.create_default_context()
    msg = EmailMessage()
    msg.set_content(message)
    msg["Subject"] = subject
    msg["From"] = config.USER
    msg["To"] = receiver

    with smtplib.SMTP_SSL(config.SMTP_SERVER, config.PORT, context=context) as server:
        server.login(config.USER, config.PASSWORD)
        server.send_message(msg)


def get_machine_local_ip():
    # Return the first IPv4 address on the local site network (hard-coded
    # "194.167" prefix); interfaces without an IPv4 address are skipped.
    interfaces = ni.interfaces()
    for inter in interfaces:
        addresses = ni.ifaddresses(inter).get(ni.AF_INET, [])
        if addresses and addresses[0]["addr"][:7] == "194.167":
            return addresses[0]["addr"]


def get_machine_infos():
    infos_uname = os.uname()

    return {
        "OS_TYPE": infos_uname[0],
        "MACHINE_NAME": infos_uname[1],
        "LOCAL_IP": get_machine_local_ip(),
    }


def gpus_snap_info():
    # Snapshot of every GPU (memory, compute processes, temperature, ...) via pynvml's nvidia-smi wrapper.
    nvsmi = nvidia_smi.getInstance()
    return nvsmi.DeviceQuery(
        "memory.free,memory.total,memory.used,compute-apps,temperature.gpu,driver_version,timestamp,name"
    )


def process_info(pid):
    try:
        process = psutil.Process(pid)

        return {
            "pid": pid,
            "owner": process.username(),
            "executed_cmd": process.cmdline(),
            "from": process.cwd(),
            "since": process.create_time(),
            "is_running": process.is_running(),
            # Note: psutil's cpu_num() is the index of the CPU the process is
            # currently running on, not the number of cores it uses.
            "cpu_core_required": f"{process.cpu_num()}/{os.cpu_count()} logic cores ({process.cpu_num()*100/os.cpu_count()}%)",
        }
    except psutil.NoSuchProcess:
        return None
49 test_torch.py (new file)
@@ -0,0 +1,49 @@
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Tiny XOR network trained (almost) forever, used to exercise gpus_monitor on a CUDA device.
EPOCHS_TO_TRAIN = 50000000000
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
    print(f"The current device for this test training is {device}. Please make sure that at least one CUDA device is available.")
    assert False


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 3, True)
        self.fc2 = nn.Linear(3, 1, True)

    def forward(self, x):
        x = F.sigmoid(self.fc1(x))
        x = self.fc2(x)
        return x


net = Net().to(device)
print(device)
inputs = list(
    map(
        lambda s: Variable(torch.Tensor([s]).to(device)),
        [[0, 0], [0, 1], [1, 0], [1, 1]],
    )
)
targets = list(
    map(lambda s: Variable(torch.Tensor([s]).to(device)), [[0], [1], [1], [0]])
)

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

print("Training loop:")
for idx in range(0, EPOCHS_TO_TRAIN):
    for input, target in zip(inputs, targets):
        optimizer.zero_grad()  # zero the gradient buffers
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()  # Does the update
    if idx % 5000 == 0:
        print(loss.item())