ricardo-lsantos commited on
Commit
b760fd0
1 Parent(s): f4ed0cd

Commented torch_directml

Browse files
AI/question_answering.py CHANGED
@@ -2,7 +2,7 @@
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
- import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
- elif DEVICE == "directml":
17
- device = torch_directml.device()
18
- dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
@@ -33,5 +33,5 @@ def clearCache(DEVICE, generator):
33
  generator.tokenizer.save_pretrained("cache")
34
  generator.model.save_pretrained("cache")
35
  del generator
36
- if DEVICE == "directml":
37
- torch_directml.empty_cache()
 
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
+ # import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
 
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
+ # elif DEVICE == "directml":
17
+ # device = torch_directml.device()
18
+ # dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
 
33
  generator.tokenizer.save_pretrained("cache")
34
  generator.model.save_pretrained("cache")
35
  del generator
36
+ # if DEVICE == "directml":
37
+ # torch_directml.empty_cache()
AI/sentiment_analysis.py CHANGED
@@ -2,7 +2,7 @@
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
- import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
- elif DEVICE == "directml":
17
- device = torch_directml.device()
18
- dtype = torch.float16
19
  return device
20
 
21
  def loadClassifier(device):
@@ -30,5 +30,5 @@ def clearCache(DEVICE, classifier):
30
  classifier.tokenizer.save_pretrained("cache")
31
  classifier.model.save_pretrained("cache")
32
  del classifier
33
- if DEVICE == "directml":
34
- torch_directml.empty_cache()
 
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
+ # import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
 
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
+ # elif DEVICE == "directml":
17
+ # device = torch_directml.device()
18
+ # dtype = torch.float16
19
  return device
20
 
21
  def loadClassifier(device):
 
30
  classifier.tokenizer.save_pretrained("cache")
31
  classifier.model.save_pretrained("cache")
32
  del classifier
33
+ # if DEVICE == "directml":
34
+ # torch_directml.empty_cache()
AI/summarization.py CHANGED
@@ -2,7 +2,7 @@
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
- import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
- elif DEVICE == "directml":
17
- device = torch_directml.device()
18
- dtype = torch.float16
19
  return device
20
 
21
  def loadSummarizer(device):
@@ -30,5 +30,5 @@ def clearCache(DEVICE, summarizer):
30
  summarizer.tokenizer.save_pretrained("cache")
31
  summarizer.model.save_pretrained("cache")
32
  del summarizer
33
- if DEVICE == "directml":
34
- torch_directml.empty_cache()
 
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
+ # import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
 
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
+ # elif DEVICE == "directml":
17
+ # device = torch_directml.device()
18
+ # dtype = torch.float16
19
  return device
20
 
21
  def loadSummarizer(device):
 
30
  summarizer.tokenizer.save_pretrained("cache")
31
  summarizer.model.save_pretrained("cache")
32
  del summarizer
33
+ # if DEVICE == "directml":
34
+ # torch_directml.empty_cache()
AI/text_generation.py CHANGED
@@ -2,7 +2,7 @@
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
- import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
- elif DEVICE == "directml":
17
- device = torch_directml.device()
18
- dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
@@ -30,5 +30,5 @@ def clearCache(DEVICE, generator):
30
  generator.tokenizer.save_pretrained("cache")
31
  generator.model.save_pretrained("cache")
32
  del generator
33
- if DEVICE == "directml":
34
- torch_directml.empty_cache()
 
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
+ # import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
 
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
+ # elif DEVICE == "directml":
17
+ # device = torch_directml.device()
18
+ # dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
 
30
  generator.tokenizer.save_pretrained("cache")
31
  generator.model.save_pretrained("cache")
32
  del generator
33
+ # if DEVICE == "directml":
34
+ # torch_directml.empty_cache()
AI/zero_shot_classification.py CHANGED
@@ -2,7 +2,7 @@
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
- import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
@@ -13,9 +13,9 @@ def getDevice(DEVICE):
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
- elif DEVICE == "directml":
17
- device = torch_directml.device()
18
- dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
@@ -30,5 +30,5 @@ def clearCache(DEVICE, generator):
30
  generator.tokenizer.save_pretrained("cache")
31
  generator.model.save_pretrained("cache")
32
  del generator
33
- if DEVICE == "directml":
34
- torch_directml.empty_cache()
 
2
  # Creation date: 2024-01-10
3
 
4
  import torch
5
+ # import torch_directml
6
  from transformers import pipeline
7
 
8
  def getDevice(DEVICE):
 
13
  elif DEVICE == "cuda":
14
  device = torch.device("cuda")
15
  dtype = torch.float16
16
+ # elif DEVICE == "directml":
17
+ # device = torch_directml.device()
18
+ # dtype = torch.float16
19
  return device
20
 
21
  def loadGenerator(device):
 
30
  generator.tokenizer.save_pretrained("cache")
31
  generator.model.save_pretrained("cache")
32
  del generator
33
+ # if DEVICE == "directml":
34
+ # torch_directml.empty_cache()
requirements.txt CHANGED
@@ -1,3 +1,2 @@
1
  torch
2
- torch_directml
3
  transformers
 
1
  torch
 
2
  transformers